Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
# @FileName  : attn_injection.py
# @Author    : Ruixiang JIANG (Songrise)
# @Time      : Mar 20, 2024
# @Github    : https://github.com/songrise
# @Description: implement attention dump and attention injection for CPSD
from __future__ import annotations | |
from dataclasses import dataclass | |
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as nnf | |
from diffusers.models import attention_processor | |
import einops | |
from diffusers.models import unet_2d_condition, attention, transformer_2d, resnet | |
from diffusers.models.unets import unet_2d_blocks | |
# from diffusers.models.unet_2d import CrossAttnUpBlock2D | |
from typing import Optional, List | |
T = torch.Tensor | |
import os | |
@dataclass(eq=False)
class StyleAlignedArgs:
    """Container of feature-sharing options (share / AdaIN toggles, score scale/shift).

    ``@dataclass`` provides a keyword constructor, e.g.
    ``StyleAlignedArgs(adain_values=True)``; without it the class accepted no
    constructor arguments. ``eq=False`` keeps identity-based equality and
    hashing, as before. Instances remain mutable.
    """

    share_group_norm: bool = True
    # Must be a real bool; a one-element tuple ``(True,)`` was used here before.
    share_layer_norm: bool = True
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    full_attention_share: bool = False
    shared_score_scale: float = 1.0
    shared_score_shift: float = 0.0
    only_self_level: float = 0.0
def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """Broadcast the first sample of each batch half across that half.

    ``feat`` is treated as two equal halves along dim 0. Every row of the first
    half is replaced by ``feat[0]`` and every row of the second half by
    ``feat[b // 2]``. When ``scale != 1``, all copies except the first row of
    each half are multiplied by ``scale``.

    Args:
        feat: tensor whose leading dimension is an even batch size ``b``; any
            number of trailing dimensions is accepted.
        scale: factor applied to the broadcast copies (not to the reference rows).

    Returns:
        Tensor with the same shape as ``feat``.
    """
    b = feat.shape[0]
    # (2, 1, *feat.shape[1:]): the reference row of each half.
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        # Repeat along the half-batch axis only. The trailing 1s follow
        # feat's rank; a fixed 5-argument repeat only worked for 4-D feat.
        feat_style = feat_style.repeat(1, b // 2, *([1] * (feat.dim() - 1)))
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """Concatenate ``feat`` with its per-half reference copy along ``dim``.

    The reference tensor is ``expand_first(feat, scale=scale)`` (same shape as
    ``feat``); ``feat`` comes first in the result.
    """
    broadcast_ref = expand_first(feat, scale=scale)
    return torch.cat([feat, broadcast_ref], dim=dim)
def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
    """Return ``(mean, std)`` of ``feat`` over its second-to-last dimension.

    Both results keep that dimension (size 1) so they broadcast against
    ``feat``. ``std`` is ``sqrt(var + eps)`` with ``Tensor.var``'s default
    (unbiased) estimator; ``eps`` keeps it strictly positive, since callers
    divide by it.
    """
    reduce_dim = -2
    mean = feat.mean(dim=reduce_dim, keepdim=True)
    std = torch.sqrt(feat.var(dim=reduce_dim, keepdim=True) + eps)
    return mean, std
def adain(feat: T) -> T:
    """Adaptive instance normalisation toward the first sample of each batch half.

    Each sample is standardised with its own statistics (``calc_mean_std``),
    then re-scaled and shifted with the statistics of ``feat[0]`` (first half)
    or ``feat[b // 2]`` (second half).
    """
    mean, std = calc_mean_std(feat)
    target_mean, target_std = expand_first(mean), expand_first(std)
    normalized = (feat - mean) / std
    return normalized * target_std + target_mean
def my_adain(feat: T) -> T:
    """AdaIN toward the sample at index 1 of each batch half, keeping rows 0 and b//2.

    The batch is split into two equal halves. Every sample is standardised with
    its own statistics and re-styled with those of ``feat[1]`` (first half) or
    ``feat[b // 2 + 1]`` (second half). Rows ``0`` and ``b // 2`` are then
    restored to their original, un-normalised values.
    """
    half = feat.shape[0] // 2
    mean, std = calc_mean_std(feat)
    # Views into the input. The arithmetic below allocates new tensors, so
    # these still hold the original rows when written back.
    content_rows = (feat[0], feat[half])

    def _broadcast_style(stat: T) -> T:
        # (2, 1, ...) -> (2, half, ...) -> (b, ...): index-1 stats of each half.
        picked = torch.stack((stat[1], stat[half + 1])).unsqueeze(1)
        return picked.expand(2, half, *stat.shape[1:]).reshape(*stat.shape)

    style_mean = _broadcast_style(mean)
    style_std = _broadcast_style(std)
    out = (feat - mean) / std * style_std + style_mean
    out[0], out[half] = content_rows
    return out
class DefaultAttentionProcessor(nn.Module):
    """Attention processor that delegates to diffusers' stock ``AttnProcessor``.

    Installed by ``unset_attention_processors`` to restore plain attention.
    """

    def __init__(self):
        super().__init__()
        # The non-2.0 AttnProcessor is chosen on purpose (original note: "for
        # torch 1.11.0"); AttnProcessor2_0 requires PyTorch 2.x SDPA.
        self.processor = attention_processor.AttnProcessor()

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        # Only the four core arguments are forwarded; any extra keyword
        # arguments are accepted and ignored.
        wrapped = self.processor
        return wrapped(attn, hidden_states, encoder_hidden_states, attention_mask)
class ArtistAttentionProcessor(DefaultAttentionProcessor):
    """Attention processor that shares q/k/v features inside a CFG batch.

    The batch is assumed to consist of two equal classifier-free-guidance
    halves. Within each half, index 0 is treated as the content sample, index 1
    as the style reference and index 2 as the instance receiving injected
    features (assumption taken from the in-code comments; confirm against the
    pipeline that builds the batch).

    Args:
        inject_query, inject_key, inject_value: which projections to share.
        use_adain: re-normalise the selected projections with ``my_adain``.
            For keys/values this replaces row copying; for queries the
            content-to-style or style-copy step still follows.
        name: optional label stored on the instance.
        use_content_to_style_injection: for queries, write the content
            sample's value rows into the style sample's query slots instead of
            copying the style query into index 2.
    """

    def __init__(
        self,
        inject_query: bool = True,
        inject_key: bool = True,
        inject_value: bool = True,
        use_adain: bool = False,
        name: str = None,
        use_content_to_style_injection=False,
    ):
        super().__init__()
        self.inject_query = inject_query
        self.inject_key = inject_key
        self.inject_value = inject_value
        self.share_enabled = True
        self.use_adain = use_adain
        self.__custom_name = name
        self.content_to_style_injection = use_content_to_style_injection

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        """Run attention like diffusers' ``AttnProcessor``, with feature sharing.

        ``scale`` is accepted for signature compatibility but unused: extra
        projection arguments are disabled (was ``() if USE_PEFT_BACKEND else (scale,)``).
        """
        residual = hidden_states
        args = ()

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states, *args)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        # Size of one CFG half. Kept separate from `batch_size`, which must stay
        # the full batch for the 4-D reshape below (overwriting it made that
        # reshape use half the batch and fail).
        half_batch = query.shape[0] // 2
        if self.share_enabled and half_batch > 1:  # a single sample has nothing to share with
            query, key, value = self._share_qkv(query, key, value, half_batch)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states, *args)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def _share_qkv(self, query: T, key: T, value: T, half_batch: int):
        """Apply the configured sharing and return the (possibly new) q, k, v.

        Rows are overwritten in place; ``my_adain`` returns new tensors.
        """
        style_u, style_c = 1, half_batch + 1
        # Index 2 of each half only exists when a half holds at least three
        # samples. With two, index 2 is the conditional content row and
        # half_batch + 2 is out of range, so no copy is done.
        has_target = half_batch > 2

        # Style reference rows, captured before any modification. They are
        # views; since my_adain allocates new tensors, they keep pre-AdaIN values.
        ref_q = (query[style_u], query[style_c])
        ref_k = (key[style_u], key[style_c])
        ref_v = (value[style_u], value[style_c])

        if self.inject_query:
            if self.use_adain:
                query = my_adain(query)
            if self.content_to_style_injection:
                # NOTE(review): the content sample's *value* rows are written
                # into the style sample's query slots; kept as-is, confirm intent.
                query[style_u] = value[0]
                query[style_c] = value[half_batch]
            elif has_target:
                query[2], query[half_batch + 2] = ref_q

        if self.inject_key:
            if self.use_adain:
                key = my_adain(key)
            elif has_target:
                key[2], key[half_batch + 2] = ref_k

        if self.inject_value:
            if self.use_adain:
                value = my_adain(value)
            elif has_target:
                value[2], value[half_batch + 2] = ref_v

        return query, key, value
class ArtistResBlockWrapper(nn.Module):
    """Wrap a resnet block so samples reuse the content sample's residual branch.

    With ``injection_method == "hidden"``, the residual-branch output ``h`` of
    sample 0 (first CFG half) and sample ``b // 2`` (second half) is added to
    every sample's own shortcut input in that half. The exception is index 1 of
    each half (the uncontrolled style sample), which keeps the block's
    unmodified output.
    """

    def __init__(
        self, block: resnet.ResnetBlock2D, injection_method: str, name: str = None
    ):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.injection_method = injection_method
        self.name = name

    def _shortcut(self, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the wrapped block's shortcut projection, if it has one."""
        # conv_shortcut can be None when no channel projection is needed;
        # calling it unconditionally raised TypeError.
        shortcut = self.block.conv_shortcut
        return input_tensor if shortcut is None else shortcut(input_tensor)

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        """Run the block and inject the content residual branch.

        Raises:
            NotImplementedError: if ``injection_method`` is not ``"hidden"``.
        """
        if self.injection_method != "hidden":
            raise NotImplementedError(f"Unknown injection method {self.injection_method}")

        # when disentangle, feat should be [recon, uncontrolled style, controlled style]
        feat = self.block(input_tensor, temb, scale)
        batch_size = feat.shape[0] // 2
        if batch_size == 1:
            return feat

        residual = self._shortcut(input_tensor)
        # The block yields feat = (residual + h) / output_scale_factor, so the
        # content (reconstruction) rows' branch is h = feat * factor - residual.
        h_uncond = feat[0:1] * self.output_scale_factor - residual[0:1]
        h_cond = (
            feat[batch_size : batch_size + 1] * self.output_scale_factor
            - residual[batch_size : batch_size + 1]
        )
        # Share only h; every sample keeps its own residual.
        h_shared = torch.cat([h_uncond] * batch_size + [h_cond] * batch_size, dim=0)
        output_feat_shared = (residual + h_shared) / self.output_scale_factor
        # Leave the uncontrolled style sample (index 1 of each half) untouched;
        # the controlled style sample (index 2) does receive the content branch.
        output_feat_shared[1] = feat[1]
        output_feat_shared[batch_size + 1] = feat[batch_size + 1]
        return output_feat_shared
class SharedResBlockWrapper(nn.Module):
    """Wrap a resnet block so every sample reuses its half's first residual branch.

    When ``share_enabled`` is True, the residual-branch output ``h`` of sample
    0 (first CFG half) and sample ``b // 2`` (second half) is added to every
    sample's own shortcut input in that half. Otherwise the block runs unchanged.
    """

    def __init__(self, block: resnet.ResnetBlock2D):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.share_enabled = True

    def _shortcut(self, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the wrapped block's shortcut projection, if it has one."""
        # conv_shortcut can be None when no channel projection is needed;
        # calling it unconditionally raised TypeError.
        shortcut = self.block.conv_shortcut
        return input_tensor if shortcut is None else shortcut(input_tensor)

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        feat = self.block(input_tensor, temb, scale)
        if not self.share_enabled:
            return feat
        batch_size = feat.shape[0] // 2
        if batch_size == 1:
            return feat

        residual = self._shortcut(input_tensor)
        # The block yields feat = (residual + h) / output_scale_factor, so the
        # first row of each half has h = feat * factor - residual.
        h_uncond = feat[0:1] * self.output_scale_factor - residual[0:1]
        h_cond = (
            feat[batch_size : batch_size + 1] * self.output_scale_factor
            - residual[batch_size : batch_size + 1]
        )
        # Share only h; every sample keeps its own residual.
        h_shared = torch.cat([h_uncond] * batch_size + [h_cond] * batch_size, dim=0)
        return (residual + h_shared) / self.output_scale_factor
def register_attention_processors(
    pipe,
    base_dir: str = None,
    disentangle: bool = False,
    attn_mode: str = "artist",
    resnet_mode: str = "hidden",
    share_resblock: bool = True,
    share_attn: bool = True,
    share_cross_attn: bool = False,
    share_attn_layers: Optional[List[int]] = None,
    share_resnet_layers: Optional[List[int]] = None,
    c2s_layers: Optional[List[int]] = None,
    share_query: bool = True,
    share_key: bool = True,
    share_value: bool = True,
    use_adain: bool = False,
):
    """Patch the up blocks of ``pipe.unet`` in place with feature-sharing modules.

    Layer indices are running counters across all selected up blocks, in order.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        base_dir: unused.
        disentangle: wrap selected resnets with ``ArtistResBlockWrapper``
            (otherwise ``SharedResBlockWrapper``).
        attn_mode: only ``"artist"`` installs attention processors.
        resnet_mode: ``injection_method`` for ``ArtistResBlockWrapper``.
        share_resblock: enable resnet wrapping (also requires ``share_resnet_layers``).
        share_attn: enable attention processor installation.
        share_cross_attn: also patch cross-attention of the selected layers.
        share_attn_layers: transformer-layer indices to patch; None patches none.
        share_resnet_layers: resnet indices to wrap; None wraps none.
        c2s_layers: attention-layer indices using content-to-style injection;
            None means the default ``[0, 1]``.
        share_query, share_key, share_value, use_adain: forwarded to the
            self-attention ``ArtistAttentionProcessor``.

    Raises:
        TypeError: if ``pipe`` is not one of the supported pipeline types.
    """
    if c2s_layers is None:
        c2s_layers = [0, 1]

    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        # skip the first block, which is UpBlock2D
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        # the loop below assumes every selected block has `attentions`
        up_blocks = unet.up_blocks[:-1]
    else:
        raise TypeError(
            "register_attention_processors supports StableDiffusionPipeline and "
            f"StableDiffusionXLPipeline, got {type(pipe).__name__}"
        )

    layer_idx_attn = 0
    layer_idx_resnet = 0
    for block in up_blocks:
        if share_resblock and share_resnet_layers is not None:
            resnet_wrappers = []
            for resnet_block in block.resnets:
                if layer_idx_resnet not in share_resnet_layers:
                    resnet_wrappers.append(resnet_block)  # use original implementation
                elif disentangle:
                    resnet_wrappers.append(
                        ArtistResBlockWrapper(
                            resnet_block,
                            injection_method=resnet_mode,
                            name=f"layer_{layer_idx_resnet}",
                        )
                    )
                    print(
                        f"Disentangle resnet {resnet_mode} set for layer {layer_idx_resnet}"
                    )
                else:
                    resnet_wrappers.append(SharedResBlockWrapper(resnet_block))
                    print(f"Share resnet feature set for layer {layer_idx_resnet}")
                layer_idx_resnet += 1
            block.resnets = nn.ModuleList(resnet_wrappers)  # actually apply the change

        if share_attn:
            for transformer_layer in block.attentions:
                # Only the first transformer block of each layer is patched.
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                self_attn: attention_processor.Attention = transformer_block.attn1
                cross_attn: attention_processor.Attention = transformer_block.attn2
                if (
                    attn_mode == "artist"
                    and share_attn_layers is not None
                    and layer_idx_attn in share_attn_layers
                ):
                    content_to_style = layer_idx_attn in c2s_layers
                    pnp_inject_processor = ArtistAttentionProcessor(
                        inject_query=share_query,
                        inject_key=share_key,
                        inject_value=share_value,
                        use_adain=use_adain,
                        name=f"layer_{layer_idx_attn}_self",
                        use_content_to_style_injection=content_to_style,
                    )
                    self_attn.set_processor(pnp_inject_processor)
                    print(
                        f"Disentangled Pnp inject processor set for self-attention in layer {layer_idx_attn} with c2s={content_to_style}"
                    )
                    if share_cross_attn:
                        cross_attn_processor = ArtistAttentionProcessor(
                            inject_query=False,
                            inject_key=True,
                            inject_value=True,
                            use_adain=False,
                            name=f"layer_{layer_idx_attn}_cross",
                        )
                        cross_attn.set_processor(cross_attn_processor)
                        print(
                            f"Disentangled Pnp inject processor set for cross-attention in layer {layer_idx_attn}"
                        )
                layer_idx_attn += 1
def unset_attention_processors(
    pipe,
    unset_share_attn: bool = False,
    unset_share_resblock: bool = False,
):
    """Revert feature-sharing modules on the up blocks of ``pipe.unet`` in place.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        unset_share_attn: reset both self- and cross-attention of the first
            transformer block in every layer to ``DefaultAttentionProcessor``
            (applied even to layers that were never patched).
        unset_share_resblock: unwrap ``SharedResBlockWrapper`` /
            ``ArtistResBlockWrapper`` resnets back to their inner block.

    Raises:
        TypeError: if ``pipe`` is not one of the supported pipeline types.
    """
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        # skip the first block, which is UpBlock2D
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        up_blocks = unet.up_blocks[:-1]
    else:
        raise TypeError(
            "unset_attention_processors supports StableDiffusionPipeline and "
            f"StableDiffusionXLPipeline, got {type(pipe).__name__}"
        )

    wrapper_types = (SharedResBlockWrapper, ArtistResBlockWrapper)
    for block in up_blocks:
        if unset_share_resblock:
            block.resnets = nn.ModuleList(
                [
                    resnet_block.block
                    if isinstance(resnet_block, wrapper_types)
                    else resnet_block
                    for resnet_block in block.resnets
                ]
            )
        if unset_share_attn:
            for transformer_layer in block.attentions:
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                transformer_block.attn1.set_processor(DefaultAttentionProcessor())
                transformer_block.attn2.set_processor(DefaultAttentionProcessor())