Live2Diff/live2diff/animatediff/models/stream_motion_module.py
leoxing1996
add demo
d16b52d
raw
history blame
7.77 kB
import torch
import torch.nn.functional as F
from einops import rearrange
from .attention import CrossAttention
from .positional_encoding import PositionalEncoding
class StreamTemporalAttention(CrossAttention):
    """Temporal self-attention that attends over a fixed-size key/value cache.

    * window_size: The max length of the attention window.
    * sink_size: The number of sink tokens.
    * positional_rule: absolute, relative

    The window consists of ``sink_size`` sink slots plus
    ``cache_size = window_size - sink_size`` cache slots. Therefore, the
    key/value seq length of the temporal self-attention is:
        sink_size + cache_size (== window_size)
    """

    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
        window_size=8,
        sink_size=0,
        *args,
        **kwargs,
    ):
        """
        Args:
            attention_mode: stored on ``self.attention_mode`` and
                ``self._orig_attention_mode``.
            cross_frame_attention_mode: accepted but not used by this class.
            temporal_position_encoding: accepted but not used; the positional
                encoder is always constructed.
            temporal_position_encoding_max_len: ``max_len`` passed to the
                positional encoder.
            window_size: number of key/value slots each query attends over.
            sink_size: how many of those slots are sink tokens.
            *args, **kwargs: forwarded to ``CrossAttention.__init__``.
                ``query_dim`` must be supplied as a keyword argument.

        Raises:
            AssertionError: if ``sink_size > window_size``.
        """
        super().__init__(*args, **kwargs)
        self.attention_mode = self._orig_attention_mode = attention_mode
        # Use .get() so that omitting `cross_attention_dim` is treated like
        # passing None (self-attention) instead of raising KeyError.
        self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.0,
            max_len=temporal_position_encoding_max_len,
        )

        self.window_size = window_size
        self.sink_size = sink_size
        self.cache_size = self.window_size - self.sink_size
        assert self.cache_size >= 0, (
            "cache_size must be greater or equal to 0. Please check your configuration. "
            f"window_size: {window_size}, sink_size: {sink_size}, "
            f"cache_size: {self.cache_size}"
        )

        # Assigned later via `set_index`; presumably identifies this motion
        # module among its siblings (TODO confirm against callers).
        self.motion_module_idx = None

    def set_index(self, idx):
        """Store ``idx`` on ``self.motion_module_idx``."""
        self.motion_module_idx = idx

    @torch.no_grad()
    def set_cache(self, denoising_steps_num: int):
        """Allocate a zero-filled key/value cache and remember ``denoising_steps_num``.

        larger buffer index means cleaner latent

        ``self.h``, ``self.w`` and ``self.kv_channels`` are read here but are not
        assigned in this class; they must be set on the instance beforehand.
        The cache is returned, not stored: the caller owns it and passes it back
        to ``forward``.

        Returns:
            Tensor of shape [denoising_steps_num, 2, h * w, window_size, kv_channels]
            on the device/dtype of this module's parameters. Index 0 of dim 1 holds
            keys and index 1 holds values.
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        # [t, 2, hw, L, c], 2 means k and v
        kv_cache = torch.zeros(
            denoising_steps_num,
            2,
            self.h * self.w,
            self.window_size,
            self.kv_channels,
            device=device,
            dtype=dtype,
        )
        # Read by prepare_qkv_full_and_cache; it must be set before forward().
        self.denoising_steps_num = denoising_steps_num

        return kv_cache

    @torch.no_grad()
    def prepare_pe_buffer(self):
        """Precompute the projected positional encodings used with the kv-cache.

        In AnimateDiff, Temporal Self-attention uses absolute positional encoding:

            q = w_q * (x + pe) + bias
            k = w_k * (x + pe) + bias
            v = w_v * (x + pe) + bias

        To conduct relative positional encoding with a kv-cache, we pre-calculate
        `w_q/k/v * pe` (registered as buffers ``q_pe``/``k_pe``/``v_pe``) and cache
        only `w_q/k/v * x + bias`. The positional term is added back at lookup time.
        """
        pe_list = self.pos_encoder.pe[:, : self.window_size]  # [1, window_size, ch]
        # F.linear without a bias term: only the `w * pe` part is precomputed.
        q_pe = F.linear(pe_list, self.to_q.weight)
        k_pe = F.linear(pe_list, self.to_k.weight)
        v_pe = F.linear(pe_list, self.to_v.weight)

        self.register_buffer("q_pe", q_pe)
        self.register_buffer("k_pe", k_pe)
        self.register_buffer("v_pe", v_pe)

    def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx):
        """Project to q/k/v, write the new k/v into the cache, and add positional terms.

        Args:
            hidden_states: [(N * bhw), F, c]
            kv_cache: [N, 2, hw, L, c], as returned by ``set_cache``. Index 0 of
                dim 1 holds keys and index 1 holds values. It is updated in place.
            pe_idx: integer positional index per cache slot, shape [N, L].
            update_idx: cache slot to overwrite for each of the N steps.

        * for warmup case: `N` should be 1 and `F` should be warmup_size (`sink_size`)
        * for streaming case: `N` should be `denoising_steps_num` and `F` should be `chunk_size`

        Only frame 0 (``[:, :, 0]``) of each chunk is written into the cache. That
        write requires ``bhw == hw``, i.e. a batch size of 1.

        Returns:
            ``(q, k_full, v_full)`` with shapes [(N * bhw), F, c], [(N * hw), L, c]
            and [(N * hw), L, c]. The positional terms are added out of place, so
            the cache keeps the PE-free projections.
        """
        q_layer = self.to_q(hidden_states)
        k_layer = self.to_k(hidden_states)
        v_layer = self.to_v(hidden_states)

        q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)
        k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)
        v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)

        # onnx & trt friendly indexing
        for idx in range(self.denoising_steps_num):
            kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0]
            kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0]

        k_full = kv_cache[:, 0]  # [n, hw, window_size, c], view into the cache
        v_full = kv_cache[:, 1]

        kv_idx = pe_idx
        # Positional index of the slot that was just written, per step.
        q_idx = torch.stack(
            [kv_idx[idx, update_idx[idx]] for idx in range(self.denoising_steps_num)]
        ).unsqueeze_(1)  # [n, 1]

        pe_k = torch.cat(
            [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, window_size, c]
        pe_v = torch.cat(
            [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, window_size, c]
        pe_q = torch.cat(
            [self.q_pe.index_select(1, q_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, 1, c]

        # Out-of-place adds: the kv_cache views above are not modified here.
        q_layer = q_layer + pe_q.unsqueeze(1)
        k_full = k_full + pe_k.unsqueeze(1)
        v_full = v_full + pe_v.unsqueeze(1)

        q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c")
        k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c")
        v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c")

        return q_layer, k_full, v_full

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        video_length=None,
        temporal_attention_mask=None,
        kv_cache=None,
        pe_idx=None,
        update_idx=None,
        *args,
        **kwargs,
    ):
        """Run streaming temporal self-attention over the cached window.

        Args:
            hidden_states: [(b * f), d, c] with ``f == video_length``. Attention
                runs along the f axis independently for each of the d positions.
            encoder_hidden_states, attention_mask: accepted but not used here.
            video_length: number of frames ``f`` packed into dim 0.
            temporal_attention_mask: attention mask specific for the temporal
                self-attention, shape [n, window_size]. It is broadcast over the
                h * w positions, the queries and the heads.
            kv_cache, pe_idx, update_idx: forwarded to
                ``prepare_qkv_full_and_cache``. ``kv_cache`` is updated in place.

        Returns:
            The attention output, rearranged back to the ``(b f) d c`` layout.
        """
        d = hidden_states.shape[1]
        hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

        if self.group_norm is not None:
            # Normalize over channels: move c to dim 1 and back.
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query_layer, key_full, value_full = self.prepare_qkv_full_and_cache(
            hidden_states, kv_cache, pe_idx, update_idx
        )

        # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head]
        query_layer = self.reshape_heads_to_batch_dim(query_layer)
        key_full = self.reshape_heads_to_batch_dim(key_full)
        value_full = self.reshape_heads_to_batch_dim(value_full)

        if temporal_attention_mask is not None:
            q_size = query_layer.shape[1]
            # [n, self.window_size] -> [n, hw, q_size, window_size]
            temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(1, self.h * self.w, q_size, 1)
            temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV")
            temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0)
        else:
            temporal_attention_mask_ = None

        # attention, what we cannot get enough of
        # NOTE: when torch provides scaled_dot_product_attention it takes
        # precedence, so the xformers branch is only reached on older torch.
        if hasattr(F, "scaled_dot_product_attention"):
            hidden_states = self._memory_efficient_attention_pt20(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
        elif self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query_layer.dtype)
        else:
            hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states