Upload 5 files
- .gitattributes +1 -0
- app.py +145 -0
- checkpoint_72001/_METADATA +0 -0
- checkpoint_72001/checkpoint +3 -0
- codi/controlnet_flax.py +396 -0
- codi/pipeline_flax_controlnet.py +610 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoint_72001/checkpoint filter=lfs diff=lfs merge=lfs -text
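Note: the added rule routes checkpoint_72001/checkpoint (roughly 1.4 GB according to the LFS pointer file below) through Git LFS; a line like this is what `git lfs track "checkpoint_72001/checkpoint"` would typically append to .gitattributes.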
app.py
ADDED
@@ -0,0 +1,145 @@
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.training import checkpoints
from diffusers import FlaxControlNetModel, FlaxUNet2DConditionModel, FlaxAutoencoderKL, FlaxDDIMScheduler
from codi.controlnet_flax import FlaxControlNetModel
from codi.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
from transformers import CLIPTokenizer, FlaxCLIPTextModel
from flax.training.common_utils import shard
from flax.jax_utils import replicate


MODEL_NAME = "CompVis/stable-diffusion-v1-4"

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    MODEL_NAME,
    subfolder="unet",
    revision="flax",
    dtype=jnp.float32,
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    MODEL_NAME,
    subfolder="vae",
    revision="flax",
    dtype=jnp.float32,
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
    MODEL_NAME,
    subfolder="text_encoder",
    revision="flax",
    dtype=jnp.float32,
)
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_NAME,
    subfolder="tokenizer",
    revision="flax",
    dtype=jnp.float32,
)

controlnet = FlaxControlNetModel(
    in_channels=unet.config.in_channels,
    down_block_types=unet.config.down_block_types,
    only_cross_attention=unet.config.only_cross_attention,
    block_out_channels=unet.config.block_out_channels,
    layers_per_block=unet.config.layers_per_block,
    attention_head_dim=unet.config.attention_head_dim,
    cross_attention_dim=unet.config.cross_attention_dim,
    use_linear_projection=unet.config.use_linear_projection,
    flip_sin_to_cos=unet.config.flip_sin_to_cos,
    freq_shift=unet.config.freq_shift,
)
scheduler = FlaxDDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    trained_betas=None,
    set_alpha_to_one=True,
    steps_offset=0,
)
scheduler_state = scheduler.create_state()

pipeline = FlaxStableDiffusionControlNetPipeline(
    vae,
    text_encoder,
    tokenizer,
    unet,
    controlnet,
    scheduler,
    None,
    None,
    dtype=jnp.float32,
)
controlnet_params = checkpoints.restore_checkpoint("experiments/checkpoint_72001", target=None)

pipeline_params = {
    "vae": vae_params,
    "unet": unet_params,
    "text_encoder": text_encoder.params,
    "scheduler": scheduler_state,
    "controlnet": controlnet_params,
}
pipeline_params = replicate(pipeline_params)

def infer(seed, prompt, negative_prompt, steps, cfgr):
    rng = jax.random.PRNGKey(int(seed))

    num_samples = jax.device_count()
    rng = jax.random.split(rng, num_samples)

    prompt_ids = pipeline.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)

    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)

    output = pipeline(
        prompt_ids=prompt_ids,
        image=None,
        params=pipeline_params,
        prng_seed=rng,
        num_inference_steps=int(steps),
        guidance_scale=float(cfgr),
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    return output_images

with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Parameter-efficient text-to-image distillation")
    gr.Markdown("[\[Paper\]](https://arxiv.org/abs/2310.01407) [\[Project Page\]](https://fast-codi.github.io)")

    with gr.Tab("CoDi on Text-to-Image"):

        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(label="Prompt")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="monochrome, lowres, bad anatomy, worst quality, low quality")
                seed = gr.Number(label="Seed", value=0)
            output = gr.Gallery(label="Output Images")

        with gr.Row():
            num_inference_steps = gr.Slider(2, 50, value=4, step=1, label="Steps")
            guidance_scale = gr.Slider(2.0, 14.0, value=7.5, step=0.5, label='Guidance Scale')
            submit_btn = gr.Button(value="Submit")
        inputs = [
            seed,
            prompt_input,
            negative_prompt,
            num_inference_steps,
            guidance_scale
        ]
        submit_btn.click(fn=infer, inputs=inputs, outputs=[output])

        with gr.Row():
            gr.Examples(
                examples=["oranges", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
                inputs=prompt_input,
                fn=infer
            )

demo.launch(max_threads=1, share=True)
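A practical note on the checkpoint restore above: `checkpoints.restore_checkpoint(..., target=None)` returns a plain nested dict of arrays, so the parameter tree can be inspected before it is replicated across devices. The sketch below is illustrative only (not part of app.py) and assumes the checkpoint directory is reachable from the working directory; app.py itself points at `experiments/checkpoint_72001`.

```py
# Illustrative only: print the leaf shapes of the restored ControlNet parameters
# so they can be compared against a freshly initialized FlaxControlNetModel.
import jax
from flax.training import checkpoints

restored = checkpoints.restore_checkpoint("checkpoint_72001", target=None)  # path is an assumption
shapes = jax.tree_util.tree_map(lambda leaf: getattr(leaf, "shape", None), restored)
print(shapes)  # nested dict mirroring the ControlNet parameter structure
```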
checkpoint_72001/_METADATA
ADDED
The diff for this file is too large to render.
See raw diff
checkpoint_72001/checkpoint
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5bf57d86d67db04cfa568f7c0399ea91d0ee5a9a4d35bda9d07dae370b04b89
size 1445128798
codi/controlnet_flax.py
ADDED
@@ -0,0 +1,396 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from diffusers.models.unets.unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
)


@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    """
    The output of [`FlaxControlNetModel`].

    Args:
        down_block_res_samples (`jnp.ndarray`):
        mid_block_res_sample (`jnp.ndarray`):
    """

    down_block_res_samples: jnp.ndarray
    mid_block_res_sample: jnp.ndarray


class FlaxControlNetConditioningEmbedding(nn.Module):
    conditioning_embedding_channels: int
    block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.conv_in = nn.Conv(
            self.block_out_channels[0],
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        blocks = []
        for i in range(len(self.block_out_channels) - 1):
            channel_in = self.block_out_channels[i]
            channel_out = self.block_out_channels[i + 1]
            conv1 = nn.Conv(
                channel_in,
                kernel_size=(3, 3),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv1)
            conv2 = nn.Conv(
                channel_out,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv2)
        self.blocks = blocks

        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
        embedding = self.conv_in(conditioning)
        embedding = nn.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = nn.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A ControlNet model.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
    """

    sample_size: int = 32
    in_channels: int = 4
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    only_cross_attention: Union[bool, Tuple[bool, ...]] = False
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    controlnet_conditioning_channel_order: str = "rgb"
    conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

    def setup(self) -> None:
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=self.conditioning_embedding_out_channels,
        )

        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # down
        down_blocks = []
        controlnet_down_blocks = []

        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv(
            output_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )
        controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)

            for _ in range(self.layers_per_block):
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

        self.down_blocks = down_blocks
        self.controlnet_down_blocks = controlnet_down_blocks

        # mid
        mid_block_channel = block_out_channels[-1]
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )

        self.controlnet_mid_block = nn.Conv(
            mid_block_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample: jnp.ndarray,
        timesteps: Union[jnp.ndarray, float, int],
        encoder_hidden_states: jnp.ndarray,
        controlnet_cond: jnp.ndarray,
        conditioning_scale: float = 1.0,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timestep (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
            conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        channel_order = self.controlnet_conditioning_channel_order

        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        if controlnet_cond is not None:
            if channel_order == "bgr":
                controlnet_cond = jnp.flip(controlnet_cond, axis=1)
            controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
            sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return FlaxControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
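The module above mirrors the upstream diffusers Flax ControlNet. As a quick orientation, the sketch below (not part of the upload, and using hand-picked SD v1.x-style shapes) shows how it could be instantiated on its own, initialized, and run on dummy inputs. app.py instead copies every hyperparameter from the pretrained UNet's config, which is the safer way to keep the two in sync.

```py
# Illustrative sketch: standalone initialization and one forward pass of the ControlNet.
import jax
import jax.numpy as jnp
from codi.controlnet_flax import FlaxControlNetModel

controlnet = FlaxControlNetModel(cross_attention_dim=768)  # 768 matches SD v1.x text embeddings
params = controlnet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 4, 64, 64))               # (batch, channels, h/8, w/8) latents
timesteps = jnp.array([999], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 768))  # CLIP text embeddings
controlnet_cond = jnp.zeros((1, 3, 512, 512))    # conditioning image in [0, 1]

down_res, mid_res = controlnet.apply(
    {"params": params},
    sample,
    timesteps,
    encoder_hidden_states,
    controlnet_cond,
    return_dict=False,
)
# down_res and mid_res are the residuals later added inside the UNet.
```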
codi/pipeline_flax_controlnet.py
ADDED
@@ -0,0 +1,610 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from diffusers.models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from diffusers.schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from diffusers.utils import PIL_INTERPOLATION, logging, replace_example_docstring
from diffusers.pipelines.pipeline_flax_utils import FlaxDiffusionPipeline
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils_flax import get_sqrt_alpha_prod
from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

from codi.controlnet_flax import FlaxControlNetModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import jax
        >>> import numpy as np
        >>> import jax.numpy as jnp
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> from diffusers.utils import load_image, make_image_grid
        >>> from PIL import Image
        >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel


        >>> def create_key(seed=0):
        ...     return jax.random.PRNGKey(seed)


        >>> rng = create_key(0)

        >>> # get canny image
        >>> canny_image = load_image(
        ...     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
        ... )

        >>> prompts = "best quality, extremely detailed"
        >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"

        >>> # load control net and stable diffusion v1-5
        >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
        ... )
        >>> params["controlnet"] = controlnet_params

        >>> num_samples = jax.device_count()
        >>> rng = jax.random.split(rng, jax.device_count())

        >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
        >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
        >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

        >>> p_params = replicate(params)
        >>> prompt_ids = shard(prompt_ids)
        >>> negative_prompt_ids = shard(negative_prompt_ids)
        >>> processed_image = shard(processed_image)

        >>> output = pipe(
        ...     prompt_ids=prompt_ids,
        ...     image=processed_image,
        ...     params=p_params,
        ...     prng_seed=rng,
        ...     num_inference_steps=50,
        ...     neg_prompt_ids=negative_prompt_ids,
        ...     jit=True,
        ... ).images

        >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
        >>> output_images = make_image_grid(output_images, num_samples // 4, 4)
        >>> output_images.save("generated_image.png")
        ```
"""


def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    scaled_timestep = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out


class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    r"""
    Flax-based pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`FlaxUNet2DConditionModel`]):
            A `FlaxUNet2DConditionModel` to denoise the encoded image latents.
        controlnet ([`FlaxControlNetModel`]):
            Provides additional conditioning to the `unet` during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        controlnet: FlaxControlNetModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warn(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )

        return text_input.input_ids

    def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        return processed_images

    def _get_has_nsfw_concepts(self, features, params):
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.ndarray,
        image: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int,
        guidance_scale: float,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
        controlnet_conditioning_scale: float = 1.0,
        height: int = 512,
        width: int = 512,
        distill_timestep_scaling: int = 10,
        distill_learning_steps: int = 50,
        onestepode_sample_eps: str = "nprediction",
    ):
        if image is not None:
            height, width = image.shape[-2:]

        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        if image is not None:
            image = jnp.concatenate([image] * 2)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents.shape[0])
            next_t = jnp.where(step < num_inference_steps - 1, jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step + 1], 0)
            next_timestep = jnp.broadcast_to(next_t, latents.shape[0])

            c_skip, c_out = scalings_for_boundary_conditions(
                timestep, timestep_scaling=distill_timestep_scaling,
            )
            alpha_t, sigma_t = get_sqrt_alpha_prod(
                scheduler_state.common,
                latents,  # only used for determining shape
                latents,  # unused code
                timestep,
            )
            alpha_s, sigma_s = get_sqrt_alpha_prod(
                scheduler_state.common,
                latents,  # only used for determining shape
                latents,  # unused code
                next_timestep,
            )

            # jax.debug.print("timestep {}", timestep)
            # jax.debug.print("next_timestep {}", next_timestep)
            # jax.debug.print("c_skip {}", c_skip)
            # jax.debug.print("c_out {}", c_out)
            # jax.debug.print("alpha_s {}", alpha_s.mean())
            # jax.debug.print("sigma_s {}", sigma_s.mean())

            c_skip = broadcast_to_shape_from_left(c_skip, latents.shape)
            c_out = broadcast_to_shape_from_left(c_out, latents.shape)

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
                {"params": params["controlnet"]},
                jnp.array(latents_input),
                jnp.array(jnp.concatenate([timestep] * 2, axis=0), dtype=jnp.int32),
                encoder_hidden_states=context,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                return_dict=False,
            )

            # predict the noise residual
            model_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            ).sample

            # perform guidance
            model_pred_uncond, model_prediction_text = jnp.split(model_pred, 2, axis=0)
            model_pred = model_pred_uncond + guidance_scale * (model_prediction_text - model_pred_uncond)

            if onestepode_sample_eps == 'nprediction':
                target_model_pred_x = (latents - sigma_t * model_pred) / alpha_t
                target_model_pred_epsilon = model_pred
            elif onestepode_sample_eps == 'vprediction':
                target_model_pred_epsilon = (
                    alpha_t * model_pred + sigma_t * latents_input
                )
                target_model_pred_x = (
                    alpha_t * latents - sigma_t * model_pred
                )
            elif onestepode_sample_eps == 'xprediction':
                target_model_pred_x = model_pred
                target_model_pred_epsilon = (latents - alpha_t * model_pred) / sigma_t
            else:
                raise NotImplementedError

            target_model_pred_x = (
                c_skip * latents + c_out * target_model_pred_x
            )

            latents = alpha_s * target_model_pred_x + sigma_s * target_model_pred_epsilon
            return latents, scheduler_state

        scheduler_state = params["scheduler"]
        skipped_schedule = self.scheduler.num_train_timesteps // distill_learning_steps
        timesteps = (jnp.arange(0, distill_learning_steps) * skipped_schedule).round()[::-1]
        step_ratio = (distill_learning_steps + num_inference_steps - 1) // num_inference_steps
        timesteps = timesteps[::step_ratio]
        scheduler_state = scheduler_state.replace(
            num_inference_steps=num_inference_steps, timesteps=timesteps
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.ndarray,
        image: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int = 50,
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        latents: jnp.ndarray = None,
        neg_prompt_ids: jnp.ndarray = None,
        controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
        return_dict: bool = True,
        jit: bool = False,
        height: int = 512,
        width: int = 512,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt_ids (`jnp.ndarray`):
                The prompt or prompts to guide the image generation.
            image (`jnp.ndarray`):
                Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
            params (`Dict` or `FrozenDict`):
                Dictionary containing the model parameters/weights.
            prng_seed (`jax.Array`):
                Array containing random number generator key.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            latents (`jnp.ndarray`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                array is generated by sampling using the supplied random `generator`.
            controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions.

                <Tip warning={true}>

                This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
                future release.

                </Tip>

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated images
                and the second element is a list of `bool`s indicating whether the corresponding generated image
                contains "not-safe-for-work" (nsfw) content.
        """

        if image is not None:
            height, width = image.shape[-2:]

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if isinstance(controlnet_conditioning_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
                height,
                width,
            )
        else:
            images = self._generate(
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
                height,
                width,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = (images * 255).round().astype("uint8")
            num_devices, batch_size = images.shape[:2]

            images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
            images = np.array(images)

            # block images
            if any(has_nsfw_concept):
                for i, is_nsfw in enumerate(has_nsfw_concept):
                    if is_nsfw:
                        images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(num_devices, batch_size, height, width, 3)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0, 0, 0),
    static_broadcasted_argnums=(0, 5, 10, 11),
)
def _p_generate(
    pipe,
    prompt_ids,
    image,
    params,
    prng_seed,
    num_inference_steps,
    guidance_scale,
    latents,
    neg_prompt_ids,
    controlnet_conditioning_scale,
    height,
    width,
):
    return pipe._generate(
        prompt_ids,
        image,
        params,
        prng_seed,
        num_inference_steps,
        guidance_scale,
        latents,
        neg_prompt_ids,
        controlnet_conditioning_scale,
        height,
        width,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)


def preprocess(image, dtype):
    image = image.convert("RGB")
    w, h = image.size
    w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = jnp.array(image).astype(dtype) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return image
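For reference, `scalings_for_boundary_conditions` above implements the consistency-model boundary conditions c_skip(t) = sigma_data^2 / ((s*t)^2 + sigma_data^2) and c_out(t) = s*t / ((s*t)^2 + sigma_data^2)^0.5, so the one-step update in `loop_body` reduces to the identity at t = 0 and leans almost entirely on the denoised prediction at large t. A small numeric check, not part of the upload, using the defaults from `_generate` (sigma_data = 0.5, timestep_scaling = 10):

```py
# Sanity check of the boundary scalings used in loop_body.
from codi.pipeline_flax_controlnet import scalings_for_boundary_conditions

for t in (0, 1, 999):
    c_skip, c_out = scalings_for_boundary_conditions(t)
    print(t, round(float(c_skip), 4), round(float(c_out), 4))

# Expected output (approximately):
#   0    1.0     0.0      -> update keeps the current latents (identity at the boundary)
#   1    0.0025  0.9988   -> almost all weight on the denoised prediction
#   999  0.0     1.0      -> pure prediction at large timesteps
```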