SunderAli17 committed on
Commit
a9823b9
·
verified ·
1 Parent(s): 7fe60bd

Create ip_adapter/attention_processor.py

Browse files
module/ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,1467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class AdaLayerNorm(nn.Module):
    """Layer normalisation modulated by a timestep embedding.

    A SiLU followed by a linear layer maps the timestep embedding to a
    ``shift`` and a ``scale``; the input is normalised without learned affine
    parameters and then transformed as ``norm(x) * (1 + scale) + shift``.
    The linear layer starts at zero, so a freshly built module behaves like a
    plain non-affine ``LayerNorm``.

    Args:
        embedding_dim: Size of the normalised (last) dimension of ``x``.
        time_embedding_dim: Size of the timestep embedding. When ``None`` it
            falls back to ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int, time_embedding_dim: int = None):
        super().__init__()

        cond_dim = embedding_dim if time_embedding_dim is None else time_embedding_dim

        self.silu = nn.SiLU()
        self.linear = nn.Linear(cond_dim, 2 * embedding_dim, bias=True)
        # Zero-init both weight and bias so the modulation starts as identity.
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor):
        """Return ``x`` normalised and modulated by ``timestep_embedding``."""
        modulation = self.linear(self.silu(timestep_embedding))
        # One modulation vector per batch element, broadcast over the middle axis.
        modulation = modulation.view(len(x), 1, -1)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale) + shift
27
+
28
+
29
class AttnProcessor(nn.Module):
    r"""
    Baseline attention processor using explicit attention probabilities.

    Attention scores come from `attn.get_attention_scores` and are applied to the values with
    `torch.bmm`. The constructor arguments are accepted but not used.
    """

    def __init__(self, hidden_size=None, cross_attention_dim=None):
        super().__init__()

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        # Keep the untouched input for the optional residual connection.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten a 4D map into a token sequence: (b, c, h, w) -> (b, h * w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and key length come from the context when one is given.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no context is supplied; otherwise optionally normalise the context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))

        # Output projection followed by dropout.
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
100
+
101
+
102
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    The last `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt tokens. They get
    their own key/value projections (`to_k_ip`, `to_v_ip`), are attended to by the same queries, and the
    result is added to the regular attention output, weighted by `scale`.

    When `encoder_hidden_states` is `None` there are no image-prompt tokens, so the processor performs plain
    self-attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (batch, channel, height, width) -> (batch, height * width, channel)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # NOTE(review): sequence_length still includes the image-prompt tokens here, while the text keys
        # below do not -- confirm how callers build `attention_mask` before relying on masking.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None in the self-attention case so the image-prompt branch below is skipped
        # instead of failing on an unbound name.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the context into regular tokens and the trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states is not None:
            # Image-prompt branch: same queries, dedicated key/value projections, no mask.
            ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
            ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
206
+
207
+
208
class TA_IPAttnProcessor(nn.Module):
    r"""
    Time-aware attention processor for IP-Adapter.

    Works like the plain IP-Adapter processor -- the last `num_tokens` entries of `encoder_hidden_states` are
    image-prompt tokens with their own key/value projections -- but the projected image keys and values are
    additionally passed through `AdaLayerNorm` layers conditioned on the timestep embedding `temb`.

    When `encoder_hidden_states` is `None` there are no image-prompt tokens, so the processor performs plain
    self-attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        time_embedding_dim (`int`, *optional*):
            Size of the timestep embedding passed to the `AdaLayerNorm` layers.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
        self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (batch, channel, height, width) -> (batch, height * width, channel)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # NOTE(review): sequence_length still includes the image-prompt tokens here, while the text keys
        # below do not -- confirm how callers build `attention_mask` before relying on masking.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None in the self-attention case so the image-prompt branch below is skipped
        # instead of failing on an unbound name.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the context into regular tokens and the trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states is not None:
            # Image-prompt branch: dedicated projections ...
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            # ... followed by time-dependent adaLN on the projected keys and values.
            ip_key = self.ln_k_ip(ip_key, temb)
            ip_value = self.ln_v_ip(ip_value, temb)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
321
+
322
+
323
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (requires PyTorch 2.0).

    If `external_kv` is truthy, its `.k` and `.v` tensors are appended to the projected keys and values
    along the sequence dimension before attention is computed.
    """

    def __init__(self, hidden_size=None, cross_attention_dim=None):
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        external_kv=None,
        temb=None,
    ):
        # Keep the untouched input for the optional residual connection.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten a 4D map into a token sequence: (b, c, h, w) -> (b, h * w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            target_length = attention_mask.shape[-1]
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, target_length)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no context is supplied; otherwise optionally normalise the context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if external_kv:
            # Extend keys/values with the externally supplied ones along the sequence axis.
            key = torch.cat([key, external_kv.k], dim=1)
            value = torch.cat([value, external_kv.v], dim=1)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def to_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = attended.to(query.dtype)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
415
+
416
+
417
class split_AttnProcessor2_0(torch.nn.Module):
    r"""
    Scaled dot-product self-attention (PyTorch 2.0) over an input made of two halves with a padding strip.

    The input is split into a first and a second half along `cat_dim`, skipping the single row/column (or,
    for sequence inputs, the `single_dim` tokens) that sits between them. Both halves are concatenated into
    one token sequence and attended jointly. Afterwards the halves are separated again and re-assembled
    around a zero-filled padding strip so that the output has the same shape as the input.

    Keys and values are always projected from `hidden_states`; `encoder_hidden_states` is only consulted for
    the batch size and sequence length used when preparing `attention_mask`.

    Call args:
        cat_dim (`int`, defaults to -2):
            Axis along which the two halves are laid out: -2/2 for height, -1/3 for width.
        original_shape (sequence of `int`, *optional*):
            Required for 3D (sequence) inputs; indexed at `[2]` (height split) or `[1]` (otherwise) to get
            the side length `single_dim` of one half. NOTE(review): the sequence split assumes each half
            holds `single_dim * single_dim` tokens -- confirm against the caller.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        time_embedding_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        external_kv=None,
        temb=None,
        cat_dim=-2,
        original_shape=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        split_along_height = cat_dim in (-2, 2)

        if input_ndim == 4:
            # 2d to sequence.
            if not split_along_height and cat_dim not in (-1, 3):
                raise ValueError(f"Unsupported cat_dim for a 4D input: {cat_dim}. Expected one of -2, 2, -1, 3.")
            height, width = hidden_states.shape[-2:]
            if split_along_height:
                hidden_states_0 = hidden_states[:, :, : height // 2, :]
                hidden_states_1 = hidden_states[:, :, -(height // 2) :, :]
            else:
                hidden_states_0 = hidden_states[:, :, :, : width // 2]
                hidden_states_1 = hidden_states[:, :, :, -(width // 2) :]
            batch_size, channel, height, width = hidden_states_0.shape
            # reshape, not view: a width slice is non-contiguous and cannot be flattened as a view.
            hidden_states_0 = hidden_states_0.reshape(batch_size, channel, height * width).transpose(1, 2)
            hidden_states_1 = hidden_states_1.reshape(batch_size, channel, height * width).transpose(1, 2)
        else:
            # directly split sequence according to concat dim.
            if original_shape is None:
                raise ValueError("`original_shape` is required when `hidden_states` is not a 4D tensor.")
            single_dim = original_shape[2] if split_along_height else original_shape[1]
            hidden_states_0 = hidden_states[:, : single_dim * single_dim, :]
            hidden_states_1 = hidden_states[:, single_dim * (single_dim + 1) :, :]

        # Attend over both halves jointly.
        hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=1)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        if external_kv:
            key = torch.cat([key, external_kv.k], dim=1)
            value = torch.cat([value, external_kv.v], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # spatially split.
        hidden_states_0, hidden_states_1 = hidden_states.chunk(2, dim=1)

        if input_ndim == 4:
            hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
            hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)

            # Zero strip that takes the place of the row/column skipped when splitting.
            pad_shape = (batch_size, channel, 1, width) if split_along_height else (batch_size, channel, height, 1)
            hidden_states_pad = torch.zeros(pad_shape, device=hidden_states_0.device, dtype=hidden_states_0.dtype)
            hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
            assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
        else:
            batch_size, sequence_length, inner_dim = hidden_states.shape
            hidden_states_pad = torch.zeros(
                batch_size, single_dim, inner_dim, device=hidden_states_0.device, dtype=hidden_states_0.dtype
            )
            hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
            assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
539
+
540
+
541
class sep_split_AttnProcessor2_0(torch.nn.Module):
    r"""
    Scaled dot-product attention (PyTorch 2.0) that processes the two halves of a split input separately.

    The input is split into a first and a second half along `cat_dim`, skipping the single row/column (or,
    for sequence inputs, the `single_dim` tokens) between them. Each half runs self-attention on its own,
    and each half additionally attends to the keys/values of the other half; that cross term is added with
    weight `ref_scale`. The keys and values of the second half pass through time-dependent `AdaLayerNorm`
    layers (`ln_k_ref`, `ln_v_ref`) first. The halves are finally re-assembled around a zero-filled padding
    strip so that the output has the same shape as the input.

    Call args:
        cat_dim (`int`, defaults to -2):
            Axis along which the two halves are laid out: -2/2 for height, -1/3 for width.
        original_shape (sequence of `int`, *optional*):
            Required for 3D (sequence) inputs; indexed at `[2]` (height split) or `[1]` (otherwise) to get
            the side length `single_dim` of one half. NOTE(review): the sequence split assumes each half
            holds `single_dim * single_dim` tokens -- confirm against the caller.
        ref_scale (`float`, defaults to 1.0):
            Weight of the cross-attention terms between the two halves.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        time_embedding_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.ln_k_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
        self.ln_v_ref = AdaLayerNorm(hidden_size, time_embedding_dim)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        external_kv=None,
        temb=None,
        cat_dim=-2,
        original_shape=None,
        ref_scale=1.0,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        split_along_height = cat_dim in (-2, 2)

        if input_ndim == 4:
            # 2d to sequence.
            if not split_along_height and cat_dim not in (-1, 3):
                raise ValueError(f"Unsupported cat_dim for a 4D input: {cat_dim}. Expected one of -2, 2, -1, 3.")
            height, width = hidden_states.shape[-2:]
            if split_along_height:
                hidden_states_0 = hidden_states[:, :, : height // 2, :]
                hidden_states_1 = hidden_states[:, :, -(height // 2) :, :]
            else:
                hidden_states_0 = hidden_states[:, :, :, : width // 2]
                hidden_states_1 = hidden_states[:, :, :, -(width // 2) :]
            batch_size, channel, height, width = hidden_states_0.shape
            # reshape, not view: a width slice is non-contiguous and cannot be flattened as a view.
            hidden_states_0 = hidden_states_0.reshape(batch_size, channel, height * width).transpose(1, 2)
            hidden_states_1 = hidden_states_1.reshape(batch_size, channel, height * width).transpose(1, 2)
        else:
            # directly split sequence according to concat dim.
            if original_shape is None:
                raise ValueError("`original_shape` is required when `hidden_states` is not a 4D tensor.")
            single_dim = original_shape[2] if split_along_height else original_shape[1]
            hidden_states_0 = hidden_states[:, : single_dim * single_dim, :]
            hidden_states_1 = hidden_states[:, single_dim * (single_dim + 1) :, :]

        batch_size, sequence_length, _ = (
            hidden_states_0.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states_0 = attn.group_norm(hidden_states_0.transpose(1, 2)).transpose(1, 2)
            hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2)

        # Both halves share the attention module's projections.
        query_0 = attn.to_q(hidden_states_0)
        query_1 = attn.to_q(hidden_states_1)
        key_0 = attn.to_k(hidden_states_0)
        key_1 = attn.to_k(hidden_states_1)
        value_0 = attn.to_v(hidden_states_0)
        value_1 = attn.to_v(hidden_states_1)

        # time-dependent adaLN (second half only)
        key_1 = self.ln_k_ref(key_1, temb)
        value_1 = self.ln_v_ref(value_1, temb)

        if external_kv:
            key_1 = torch.cat([key_1, external_kv.k], dim=1)
            value_1 = torch.cat([value_1, external_kv.v], dim=1)

        inner_dim = key_0.shape[-1]
        head_dim = inner_dim // attn.heads

        query_0 = query_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        query_1 = query_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key_0 = key_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key_1 = key_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value_0 = value_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value_1 = value_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states_0 = F.scaled_dot_product_attention(
            query_0, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states_1 = F.scaled_dot_product_attention(
            query_1, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # cross-attn: first half queries attend to the second half.
        _hidden_states_0 = F.scaled_dot_product_attention(
            query_0, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # NOTE(review): the extra factor of 10 is hard-coded here -- confirm it is intentional.
        hidden_states_0 = hidden_states_0 + ref_scale * _hidden_states_0 * 10

        # TODO: drop this cross-attn
        _hidden_states_1 = F.scaled_dot_product_attention(
            query_1, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states_1 = hidden_states_1 + ref_scale * _hidden_states_1

        hidden_states_0 = hidden_states_0.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states_1 = hidden_states_1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states_0 = hidden_states_0.to(query_0.dtype)
        hidden_states_1 = hidden_states_1.to(query_1.dtype)

        # linear proj
        hidden_states_0 = attn.to_out[0](hidden_states_0)
        hidden_states_1 = attn.to_out[0](hidden_states_1)
        # dropout
        hidden_states_0 = attn.to_out[1](hidden_states_0)
        hidden_states_1 = attn.to_out[1](hidden_states_1)

        if input_ndim == 4:
            hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
            hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)

            # Zero strip that takes the place of the row/column skipped when splitting.
            pad_shape = (batch_size, channel, 1, width) if split_along_height else (batch_size, channel, height, 1)
            hidden_states_pad = torch.zeros(pad_shape, device=hidden_states_0.device, dtype=hidden_states_0.dtype)
            hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
            assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
        else:
            # `hidden_states` is still the (spatially normalised) 3D input at this point.
            batch_size, sequence_length, inner_dim = hidden_states.shape
            hidden_states_pad = torch.zeros(
                batch_size, single_dim, inner_dim, device=hidden_states_0.device, dtype=hidden_states_0.dtype
            )
            hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
            assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
701
+
702
+
703
class AdditiveKV_AttnProcessor2_0(torch.nn.Module):
    r"""
    Scaled dot-product attention processor (PyTorch 2.0) with an optional additive external key/value branch.

    The regular attention output is always computed. When `external_kv` is truthy, the same queries attend a second
    time over `external_kv.k` / `external_kv.v`; that result is multiplied by `additive_scale` and added to the
    regular output before the output projection.

    Args:
        hidden_size (`int`, *optional*):
            Accepted but not used by this processor.
        cross_attention_dim (`int`, *optional*):
            Accepted but not used by this processor.
        time_embedding_dim (`int`, *optional*):
            Accepted but not used by this processor.
        additive_scale (`float`, defaults to 1.0):
            Weight of the external key/value attention output.
    """

    def __init__(
        self,
        hidden_size: int = None,
        cross_attention_dim: int = None,
        time_embedding_dim: int = None,
        additive_scale: float = 1.0,
    ):
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.additive_scale = additive_scale

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        external_kv=None,
        attention_mask=None,
        temb=None,
    ):
        assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."

        # Kept untouched for the optional residual connection at the end.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        def merge_heads(tensor):
            # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
            return tensor.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)

        query = split_heads(query)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attn_output = merge_heads(
            F.scaled_dot_product_attention(
                query, split_heads(key), split_heads(value), attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        )

        if external_kv:
            # Second attention pass: same queries (and same mask) over the external keys/values.
            ext_key, ext_value = external_kv.k, external_kv.v
            ext_output = F.scaled_dot_product_attention(
                query,
                split_heads(ext_key),
                split_heads(ext_value),
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )
            attn_output = attn_output + self.additive_scale * merge_heads(ext_output)

        attn_output = attn_output.to(query.dtype)

        # linear proj followed by dropout
        projected = attn.to_out[0](attn_output)
        projected = attn.to_out[1](projected)

        if input_ndim == 4:
            projected = projected.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            projected = projected + shortcut

        return projected / attn.rescale_output_factor
811
+
812
+
813
class TA_AdditiveKV_AttnProcessor2_0(torch.nn.Module):
    r"""
    Time-aware variant of the additive external key/value attention processor (PyTorch 2.0).

    The regular attention output is always computed. When `external_kv` is truthy, `external_kv.k` and
    `external_kv.v` are each passed through an `AdaLayerNorm` conditioned on `temb`, attended with the same queries,
    and the result is added to the regular output with weight `additive_scale`.

    Args:
        hidden_size (`int`, *optional*):
            Forwarded to the two `AdaLayerNorm` layers.
        cross_attention_dim (`int`, *optional*):
            Accepted but not used by this processor.
        time_embedding_dim (`int`, *optional*):
            Forwarded to the two `AdaLayerNorm` layers.
        additive_scale (`float`, defaults to 1.0):
            Weight of the external key/value attention output.
    """

    def __init__(
        self,
        hidden_size: int = None,
        cross_attention_dim: int = None,
        time_embedding_dim: int = None,
        additive_scale: float = 1.0,
    ):
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        # Separate timestep-conditioned norms for the external keys and values.
        self.ln_k = AdaLayerNorm(hidden_size, time_embedding_dim)
        self.ln_v = AdaLayerNorm(hidden_size, time_embedding_dim)
        self.additive_scale = additive_scale

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        external_kv=None,
        attention_mask=None,
        temb=None,
    ):
        assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."

        # Kept untouched for the optional residual connection at the end.
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        n_heads = attn.heads
        head_dim = key.shape[-1] // n_heads

        def to_heads(tensor):
            # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
            return tensor.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)

        def from_heads(tensor):
            # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
            return tensor.transpose(1, 2).reshape(batch_size, -1, n_heads * head_dim)

        query = to_heads(query)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = from_heads(
            F.scaled_dot_product_attention(
                query, to_heads(key), to_heads(value), attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        )

        if external_kv:
            ext_key, ext_value = external_kv.k, external_kv.v

            # time-dependent adaLN
            ext_key = self.ln_k(ext_key, temb)
            ext_value = self.ln_v(ext_value, temb)

            ext_attended = F.scaled_dot_product_attention(
                query,
                to_heads(ext_key),
                to_heads(ext_value),
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )
            attended = attended + self.additive_scale * from_heads(ext_attended)

        attended = attended.to(query.dtype)

        # linear proj followed by dropout
        result = attn.to_out[0](attended)
        result = attn.to_out[1](result)

        if input_ndim == 4:
            result = result.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            result = result + skip_input

        return result / attn.rescale_output_factor
927
+
928
+
929
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    The trailing `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt tokens. They are
    projected with this processor's own `to_k_ip` / `to_v_ip` layers, attended with the same queries as the text
    tokens, and added to the text attention output with weight `scale`. When `encoder_hidden_states` is `None`
    there are no image tokens and the processor performs plain self-attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Dedicated key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if isinstance(encoder_hidden_states, tuple):
            # FIXME: now hard coded to single image prompt.
            # Tuple form is (text_states, [image_tokens, ...]); fold it back into the
            # concatenated layout that the split below expects.
            ip_tokens = encoder_hidden_states[1][0]
            encoder_hidden_states = torch.cat([encoder_hidden_states[0], ip_tokens], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # NOTE(review): sequence_length here still includes the image tokens, while the
            # text keys below exclude them - confirm the mask length matches what callers pass.
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None on the self-attention path so the image branch can be skipped
        # instead of failing on an unbound name.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
1061
+
1062
+
1063
class TA_IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Time-aware attention processor for IP-Adapter for PyTorch 2.0.

    Image-prompt tokens are projected with `to_k_ip` / `to_v_ip`, passed through `AdaLayerNorm` layers conditioned
    on `temb`, attended with the same queries as the text tokens, and added to the text attention output with weight
    `scale`. When `encoder_hidden_states` is `None` there are no image tokens and only the regular attention runs.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        time_embedding_dim (`int`, *optional*):
            Forwarded to the `AdaLayerNorm` layers applied to the image-prompt keys and values.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Dedicated key/value projections and timestep-conditioned norms for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
        self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        external_kv=None,
        temb=None,
    ):
        assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            # Self-attention: there are no image-prompt tokens to split off.
            ip_hidden_states = None
        elif isinstance(encoder_hidden_states, tuple):
            # FIXME: now hard coded to single image prompt.
            ip_hidden_states = encoder_hidden_states[1][0]
            encoder_hidden_states = encoder_hidden_states[0]
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if external_kv:
            # Append the external keys/values along the token axis.
            # NOTE(review): a non-None attention_mask was sized before this concatenation - confirm.
            key = torch.cat([key, external_kv.k], dim=1)
            value = torch.cat([value, external_kv.v], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            # time-dependent adaLN
            ip_key = self.ln_k_ip(ip_key, temb)
            ip_value = self.ln_v_ip(ip_value, temb)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
1208
+
1209
+
1210
+ ## for controlnet
1211
class CNAttnProcessor:
    r"""
    Attention processor that drops the trailing `num_tokens` entries of `encoder_hidden_states` before running
    attention through `attn.get_attention_scores`, so only the leading (text) tokens are attended.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        # Kept untouched for the optional residual connection at the end.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_end = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_end]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = [attn.head_to_batch_dim(tensor) for tensor in (query, key, value)]

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # linear proj followed by dropout
        output = attn.to_out[0](attended)
        output = attn.to_out[1](output)

        if input_ndim == 4:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + shortcut

        return output / attn.rescale_output_factor
1274
+
1275
+
1276
class CNAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor (PyTorch 2.0) that drops the trailing `num_tokens` entries of
    `encoder_hidden_states`, so only the leading (text) tokens are attended.
    """

    def __init__(self, num_tokens=4):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        # Kept untouched for the optional residual connection at the end.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_end = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_end]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        n_heads = attn.heads
        head_dim = key.shape[-1] // n_heads

        def to_heads(tensor):
            # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
            return tensor.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, n_heads * head_dim).to(query.dtype)

        # linear proj followed by dropout
        output = attn.to_out[0](attended)
        output = attn.to_out[1](output)

        if input_ndim == 4:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + shortcut

        return output / attn.rescale_output_factor
1362
+
1363
+
1364
def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, use_external_kv=False):
    """Build a `{processor_name: processor}` mapping for every attention layer of `unet`.

    Layers whose name ends with `attn1.processor` (or any layer when `unet.config.cross_attention_dim` is None)
    get an additive external-KV processor if `use_external_kv` is set, otherwise a plain attention processor.
    The remaining layers get a time-aware IP-Adapter processor when `use_adaln` is set, with its `to_k_ip` /
    `to_v_ip` weights initialised from the layer's own `to_k` / `to_v` weights, otherwise a plain processor.

    Args:
        unet: Model exposing `attn_processors`, `state_dict()`, `config.block_out_channels` and
            `config.cross_attention_dim`.
        ip_adapter_tokens: Passed as `num_tokens` to the IP-Adapter processors.
        use_lcm: When True, read the seed weights from `<layer>.to_k.base_layer.weight` /
            `<layer>.to_v.base_layer.weight` instead of `<layer>.to_k.weight` / `<layer>.to_v.weight`.
        use_adaln: Use the time-aware IP-Adapter processor on cross-attention layers.
        use_external_kv: Use the additive external-KV processor on self-attention layers.

    Returns:
        dict mapping each name in `unet.attn_processors` to a new processor instance.

    Raises:
        ValueError: If a processor that needs a hidden size belongs to a layer whose name does not start with
            `mid_block`, `up_blocks` or `down_blocks`.
    """
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    block_out_channels = unet.config.block_out_channels
    # The UNet weights are only needed to seed the image-prompt K/V projections.
    unet_sd = unet.state_dict() if use_adaln else None

    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

        # Resolve the layer width from its block. The index is the second dotted component,
        # e.g. "up_blocks.2.attentions..." -> 2 (also valid for multi-digit indices).
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            hidden_size = list(reversed(block_out_channels))[int(name.split(".")[1])]
        elif name.startswith("down_blocks"):
            hidden_size = block_out_channels[int(name.split(".")[1])]
        else:
            hidden_size = None

        # Never fall through with a stale size left over from the previous iteration.
        needs_hidden_size = use_external_kv if cross_attention_dim is None else use_adaln
        if needs_hidden_size and hidden_size is None:
            raise ValueError(f"Cannot infer hidden size for attention processor '{name}'.")

        if cross_attention_dim is None:
            if use_external_kv:
                if has_sdpa:
                    attn_procs[name] = AdditiveKV_AttnProcessor2_0(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        time_embedding_dim=1280,
                    )
                else:
                    attn_procs[name] = AdditiveKV_AttnProcessor()
            else:
                attn_procs[name] = AttnProcessor2_0() if has_sdpa else AttnProcessor()
        elif use_adaln:
            layer_name = name.split(".processor")[0]
            weight_suffix = ".base_layer.weight" if use_lcm else ".weight"
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k" + weight_suffix],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v" + weight_suffix],
            }
            processor_cls = TA_IPAttnProcessor2_0 if has_sdpa else TA_IPAttnProcessor
            attn_procs[name] = processor_cls(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                num_tokens=ip_adapter_tokens,
                time_embedding_dim=1280,
            )
            # strict=False: only the two image-prompt projections are seeded here.
            attn_procs[name].load_state_dict(weights, strict=False)
        else:
            attn_procs[name] = AttnProcessor2_0() if has_sdpa else AttnProcessor()

    return attn_procs
1416
+
1417
+
1418
def init_aggregator_attn_proc(unet, use_adaln=False, split_attn=False):
    """Build a `{processor_name: processor}` mapping for every attention layer of the aggregator `unet`.

    With `split_attn` set, each layer gets `sep_split_AttnProcessor2_0` (when `use_adaln`) or
    `split_AttnProcessor2_0`; otherwise each layer gets `AttnProcessor2_0`, or `AttnProcessor` when
    `F.scaled_dot_product_attention` is unavailable.

    Args:
        unet: Model exposing `attn_processors`, `config.block_out_channels` and `config.cross_attention_dim`.
        use_adaln: Only consulted when `split_attn` is True; selects `sep_split_AttnProcessor2_0`.
        split_attn: Use the split attention processors instead of the plain ones.

    Returns:
        dict mapping each name in `unet.attn_processors` to a new processor instance.

    Raises:
        ValueError: If a processor name does not start with `mid_block`, `up_blocks` or `down_blocks`, so its
            hidden size cannot be determined.
    """
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    block_out_channels = unet.config.block_out_channels

    attn_procs = {}
    for name in unet.attn_processors.keys():
        # get layer name and hidden dim
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

        # The block index is the second dotted component, e.g. "down_blocks.1.attentions..." -> 1.
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            hidden_size = list(reversed(block_out_channels))[int(name.split(".")[1])]
        elif name.startswith("down_blocks"):
            hidden_size = block_out_channels[int(name.split(".")[1])]
        else:
            # Every processor built below needs a size; never reuse one from a previous iteration.
            raise ValueError(f"Cannot infer hidden size for attention processor '{name}'.")

        # init attn proc
        if split_attn:
            if use_adaln:
                attn_procs[name] = sep_split_AttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=hidden_size,
                    time_embedding_dim=1280,
                )
            else:
                attn_procs[name] = split_AttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    time_embedding_dim=1280,
                )
        else:
            processor_cls = AttnProcessor2_0 if has_sdpa else AttnProcessor
            attn_procs[name] = processor_cls(
                hidden_size=hidden_size,
                cross_attention_dim=hidden_size,
            )

    return attn_procs