ford442 committed on
Commit 28acd08 · verified · 1 Parent(s): 12bfca0

Upload 10 files

ip_adapter/__init__.py ADDED
@@ -0,0 +1 @@
from .ip_adapter import IPAdapter, IPAdapterXL, IPAdapterPlus
ip_adapter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (234 Bytes).
 
ip_adapter/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (9.71 kB).
 
ip_adapter/__pycache__/ip_adapter.cpython-310.pyc ADDED
Binary file (8.17 kB).
 
ip_adapter/__pycache__/resampler.cpython-310.pyc ADDED
Binary file (3.17 kB).
 
ip_adapter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (362 Bytes).
 
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,553 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into the text part and the image (ip) part
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into the text part and the image (ip) part
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


## for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
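
The core idea in this file is easier to see outside the diffusers plumbing: text tokens and image-prompt tokens get separate key/value projections, and the image branch is blended in with a tunable scale. Below is a minimal, self-contained sketch of that decoupled cross-attention (not part of this upload; the shapes and the untrained to_k_ip / to_v_ip stand-ins are illustrative assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, heads, head_dim = 2, 8, 64
hidden_size = heads * head_dim            # 512
query_len, text_len, num_ip_tokens = 1024, 77, 4
scale = 1.0                               # weight of the image prompt

# stand-ins for the adapter's extra projections (illustrative, untrained)
to_k_ip = nn.Linear(hidden_size, hidden_size, bias=False)
to_v_ip = nn.Linear(hidden_size, hidden_size, bias=False)

def split_heads(x):
    # (batch, tokens, hidden) -> (batch, heads, tokens, head_dim)
    return x.view(batch, -1, heads, head_dim).transpose(1, 2)

query = torch.randn(batch, heads, query_len, head_dim)      # UNet latent queries
text_kv = torch.randn(batch, heads, text_len, head_dim)     # pretend text keys/values
ip_tokens = torch.randn(batch, num_ip_tokens, hidden_size)  # projected image tokens

text_out = F.scaled_dot_product_attention(query, text_kv, text_kv)
ip_out = F.scaled_dot_product_attention(
    query, split_heads(to_k_ip(ip_tokens)), split_heads(to_v_ip(ip_tokens))
)
out = text_out + scale * ip_out   # decoupled cross-attention, as in IPAttnProcessor2_0
print(out.shape)                  # torch.Size([2, 8, 1024, 64])
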
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,273 @@
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image

from .utils import is_torch2_available
if is_torch2_available():
    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, CNAttnProcessor2_0 as CNAttnProcessor
else:
    from .attention_processor import IPAttnProcessor, AttnProcessor, CNAttnProcessor
from .resampler import Resampler


class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class IPAdapter:

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):

        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.bfloat16)
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()
        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.bfloat16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                   scale=1.0, num_tokens=self.num_tokens).to(self.device, dtype=torch.bfloat16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def update_state_dict(self, state_dict):
        image_proj_dict = {}
        ip_adapter_dict = {}

        for k in state_dict.keys():
            if k.startswith("image_proj_model"):
                image_proj_dict[k.replace("image_proj_model.", "")] = state_dict[k]
            if k.startswith("adapter_modules"):
                ip_adapter_dict[k.replace("adapter_modules.", "")] = state_dict[k]

        dict = {'image_proj': image_proj_dict,
                'ip_adapter': ip_adapter_dict
                }
        return dict

    def load_ip_adapter(self):
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        if "image_proj_model.proj.weight" in state_dict.keys():
            state_dict = self.update_state_dict(state_dict)
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=-1,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        if isinstance(pil_image, List):
            num_prompts = len(pil_image)
        else:
            num_prompts = 1

        # if isinstance(pil_image, Image.Image):
        #     num_prompts = 1
        # else:
        #     num_prompts = len(pil_image)
        # print("num promp", num_prompts)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds = self.pipe._encode_prompt(
                prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)

            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=-1,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
                prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.bfloat16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds
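
For context, a hedged usage sketch of how these classes are typically driven; the base model ID, the file paths, and the scale value below are placeholders and assumptions (bfloat16 is chosen only to match the dtype this class casts its own modules to), not artifacts of this commit:

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from ip_adapter import IPAdapter

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
)

ip_model = IPAdapter(
    pipe,
    image_encoder_path="path/to/clip_image_encoder",  # placeholder: CLIP vision model directory
    ip_ckpt="path/to/ip-adapter_sd15.bin",            # placeholder: adapter checkpoint
    device="cuda",
)

images = ip_model.generate(
    pil_image=Image.open("reference.png"),  # the image prompt
    prompt="best quality, high quality",
    scale=0.6,                              # strength of the image prompt
    num_samples=1,
    num_inference_steps=30,
    seed=42,
)
images[0].save("output.png")
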
ip_adapter/resampler.py ADDED
@@ -0,0 +1,121 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # ensure a contiguous (bs, n_heads, length, dim_per_head) layout
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
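
As a quick, illustrative shape check (assuming an SD 1.5 UNet with cross_attention_dim=768 and a CLIP ViT-H/14 image encoder with hidden_size=1280 and 257 tokens, which is how IPAdapterPlus configures the Resampler above):

import torch
from ip_adapter.resampler import Resampler

resampler = Resampler(
    dim=768, depth=4, dim_head=64, heads=12,
    num_queries=16, embedding_dim=1280, output_dim=768, ff_mult=4,
)
# penultimate CLIP hidden states: (batch, patch + CLS tokens, hidden_size)
clip_hidden_states = torch.randn(2, 257, 1280)
image_tokens = resampler(clip_hidden_states)
print(image_tokens.shape)  # torch.Size([2, 16, 768])
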
ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
import torch.nn.functional as F


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")