Freiburg-AI-Research committed
Commit a0c720e · 1 Parent(s): 0990b67

Upload 6 files

glide_text2im/clip/__init__.py ADDED
File without changes
glide_text2im/clip/attention.py ADDED
@@ -0,0 +1,179 @@
+ import math
+ from abc import ABC, abstractmethod
+ from itertools import product
+ from typing import Any, Optional
+
+ import attr
+ import numpy as np
+ import torch
+
+
+ @attr.s
+ class AttentionMask(ABC):
+     query_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
+     key_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
+     block_size: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
+     n_head: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
+     is_head_specific: bool = attr.ib(default=False)
+     n_query_pad: int = attr.ib(default=0)
+     n_key_pad: int = attr.ib(default=0)
+
+     def __attrs_post_init__(self) -> None:
+         if self.query_context_size % self.block_size != 0:
+             raise ValueError()
+         if self.key_context_size % self.block_size != 0:
+             raise ValueError()
+         if self.n_query_pad >= self.query_context_size:
+             raise ValueError()
+         if self.n_key_pad >= self.key_context_size:
+             raise ValueError()
+
+         self.n_query_block = self.query_context_size // self.block_size
+         self.n_key_block = self.key_context_size // self.block_size
+         self.first_pad_query_block_idx = self.n_query_block - int(
+             math.ceil(self.n_query_pad / self.block_size)
+         )
+         self.first_pad_key_block_idx = self.n_key_block - int(
+             math.ceil(self.n_key_pad / self.block_size)
+         )
+
+     def _make_global_layout(self) -> None:
+         if not self.is_head_specific:
+             m = np.ones([self.n_query_block, self.n_key_block], dtype=bool)
+             r = product(*[range(n) for n in m.shape])
+
+             for qb, kb in r:
+                 m[qb, kb] = np.any(self.block_layout(None, 0, qb, kb, 0))
+         else:
+             m = np.ones([self.n_head, self.n_query_block, self.n_key_block], dtype=bool)
+             r = product(*[range(n) for n in m.shape])
+
+             for h, qb, kb in r:
+                 m[h, qb, kb] = np.any(self.block_layout(None, h, qb, kb, 0))
+
+         self.global_layout = m
+
+     @abstractmethod
+     def _block_layout(
+         self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
+     ) -> np.ndarray:
+         raise NotImplementedError()
+
+     def block_layout(
+         self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
+     ) -> np.ndarray:
+         """
+         `query_idx`, `key_idx` are block-level, zero-based indices.
+         """
+
+         m = np.ones([self.block_size, self.block_size], dtype=bool)
+
+         if query_idx >= self.first_pad_query_block_idx:
+             n_pad = min(
+                 self.block_size,
+                 (query_idx + 1) * self.block_size - (self.query_context_size - self.n_query_pad),
+             )
+             assert n_pad > 0
+             m[self.block_size - n_pad :] = False
+         if key_idx >= self.first_pad_key_block_idx:
+             n_pad = min(
+                 self.block_size,
+                 (key_idx + 1) * self.block_size - (self.key_context_size - self.n_key_pad),
+             )
+             assert n_pad > 0
+             m[:, self.block_size - n_pad :] = False
+
+         return m & self._block_layout(blk_shape, head_idx, query_idx, key_idx, blk_idx)
+
+
+ @attr.s
+ class DenseAttentionMask(AttentionMask):
+     def __attrs_post_init__(self) -> None:
+         super().__attrs_post_init__()
+
+         self.global_layout = np.ones([self.n_query_block, self.n_key_block], dtype=bool)
+         n_zero_query_blocks = self.n_query_pad // self.block_size
+         n_zero_key_blocks = self.n_key_pad // self.block_size
+         self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
+         self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False
+
+     def _block_layout(
+         self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
+     ) -> np.ndarray:
+         return np.ones([self.block_size, self.block_size], dtype=bool)
+
+
+ @attr.s
+ class DenseCausalAttentionMask(AttentionMask):
+     def __attrs_post_init__(self) -> None:
+         super().__attrs_post_init__()
+
+         self.global_layout = np.tril(np.ones([self.n_query_block, self.n_key_block], dtype=bool))
+         n_zero_query_blocks = self.n_query_pad // self.block_size
+         n_zero_key_blocks = self.n_key_pad // self.block_size
+         self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
+         self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False
+
+     def _block_layout(
+         self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
+     ) -> np.ndarray:
+         if query_idx > key_idx:
+             return np.ones(2 * [self.block_size], dtype=bool)
+         elif query_idx < key_idx:
+             return np.zeros(2 * [self.block_size], dtype=bool)
+         else:
+             return np.tril(np.ones(2 * [self.block_size], dtype=bool))
+
+
+ @attr.s(eq=False, repr=False)
+ class AttentionInfo:
+     n_heads: int = attr.ib()
+     ctx_blks_q: int = attr.ib()
+     ctx_blks_k: int = attr.ib()
+     block_size: int = attr.ib()
+     pytorch_attn_bias: Optional[torch.Tensor] = attr.ib()
+
+
+ def to_attention_info(d: AttentionMask) -> AttentionInfo:
+     return AttentionInfo(
+         n_heads=d.n_head,
+         ctx_blks_q=d.n_query_block,
+         ctx_blks_k=d.n_key_block,
+         block_size=d.block_size,
+         pytorch_attn_bias=None,
+     )
+
+
+ def make_full_layout(d: AttentionMask) -> np.ndarray:
+     """
+     Returns the `context_size x context_size` layout matrix described by `d`. If the layout is dependent on the index of
+     the attention head, an `attention_head x context_size x context_size` layout matrix is returned instead.
+     """
+
+     if not d.is_head_specific:
+         u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1])
+         r = product(range(d.n_query_block), range(d.n_key_block))
+         v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r])
+         v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size])
+
+         w = u * v
+         w = np.transpose(w, [0, 2, 1, 3])
+         w = np.reshape(w, [d.query_context_size, d.key_context_size])
+         return w
+     else:
+         if len(d.global_layout.shape) == 2:
+             u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1])
+             u = np.tile(u, [d.n_head, 1, 1, 1, 1])
+         elif len(d.global_layout.shape) == 3:
+             u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1])
+         else:
+             raise RuntimeError()
+
+         s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block))
+         v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s])
+         v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size])
+
+         w = u * v
+         w = np.transpose(w, [0, 1, 3, 2, 4])
+         w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size])
+         return w
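
The mask classes above are consumed elsewhere in this commit (see encoders.py below) by turning the block layout into an additive attention bias. A minimal sketch of that conversion, assuming the package is importable as glide_text2im.clip:

import numpy as np
import torch

from glide_text2im.clip.attention import (
    DenseCausalAttentionMask,
    make_full_layout,
    to_attention_info,
)

# A 77-token text context rounded up to a multiple of the 32-token block size.
block_size = 32
max_text_len = 77
n_rounded = block_size * int(np.ceil(max_text_len / block_size))  # 96
n_pad = n_rounded - max_text_len  # 19 padded positions

mask = DenseCausalAttentionMask(
    query_context_size=n_rounded,
    key_context_size=n_rounded,
    block_size=block_size,
    n_head=8,
    is_head_specific=False,
    n_query_pad=n_pad,
    n_key_pad=n_pad,
)
attn_info = to_attention_info(mask)

# Positions the layout disallows receive a large negative bias before softmax.
bias = 1 - make_full_layout(mask).astype(np.float32)
bias[bias == 1] = -1e10
attn_info.pytorch_attn_bias = torch.from_numpy(bias)  # shape (96, 96)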
glide_text2im/clip/config.yaml ADDED
@@ -0,0 +1,18 @@
+ logit_scale: 100.0
+
+ # Diffusion settings
+ beta_schedule: "squaredcos_cap_v2"
+ n_timesteps: 1000
+
+ # Architecture settings
+ image_size: 64
+ patch_size: 4
+ n_vocab: 65536
+ max_text_len: 77
+ n_embd: 512
+ n_head_state_text: 64
+ n_head_text: 8
+ n_xf_blocks_text: 12
+ n_head_state_image: 64
+ n_head_image: 12
+ n_xf_blocks_image: 12
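
These hyperparameters are read with a plain YAML load (model_creation.py below does exactly this). A minimal sketch, assuming the file sits at its in-repo path:

import yaml

with open("glide_text2im/clip/config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Transformer width is heads times per-head state, e.g. for the text tower:
n_state_text = config["n_head_text"] * config["n_head_state_text"]  # 8 * 64 = 512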
glide_text2im/clip/encoders.py ADDED
@@ -0,0 +1,497 @@
+ import math
+ from collections import OrderedDict
+ from typing import List, Optional, Tuple, cast
+
+ import attr
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .attention import (
+     AttentionInfo,
+     DenseAttentionMask,
+     DenseCausalAttentionMask,
+     make_full_layout,
+     to_attention_info,
+ )
+ from .utils import Affine, LayerNorm, zero_key_bias_grad
+
+ # Constants used in the original CLIP implementation.
+ image_channel_means = [122.77093945, 116.74601272, 104.09373519]
+ image_channel_stds = [68.50053285, 66.63215831, 70.32316309]
+
+
+ @attr.s(eq=False, repr=False)
+ class TextEmbedding(nn.Module):
+     n_vocab: int = attr.ib()
+     n_context: int = attr.ib()
+     n_state: int = attr.ib()
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         w_voc = torch.empty((self.n_vocab, self.n_state), dtype=torch.float32, device=self.device)
+         w_pos = torch.empty((self.n_context, self.n_state), dtype=torch.float32, device=self.device)
+
+         with torch.no_grad():
+             w_voc.normal_(std=0.02)
+             w_pos.normal_(std=0.01)
+
+         self.w_voc = nn.Parameter(w_voc)
+         self.w_pos = nn.Parameter(w_pos)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if len(x.shape) != 2:
+             raise ValueError()
+
+         return F.embedding(x, self.w_voc) + self.w_pos[None, :, :]
+
+
+ @attr.s(eq=False, repr=False)
+ class ImageEmbedding(nn.Module):
+     image_size: int = attr.ib()
+     patch_size: int = attr.ib()
+     n_state: int = attr.ib()
+     n_timestep: int = attr.ib(default=0)
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         if self.image_size % self.patch_size != 0:
+             raise ValueError()
+
+         n_patch = self.image_size // self.patch_size
+         patch_proj = torch.empty(
+             (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device
+         )
+         w_pos = torch.empty(
+             (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device
+         )
+
+         with torch.no_grad():
+             if self.n_timestep == 0:
+                 pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device)
+                 pred_state.normal_(std=1 / np.sqrt(self.n_state))
+                 self.pred_state = nn.Parameter(pred_state)
+             else:
+                 w_t = torch.empty(
+                     (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device
+                 )
+                 w_t.normal_(std=1 / np.sqrt(self.n_state))
+                 self.w_t = nn.Parameter(w_t)
+
+             patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2)))
+             w_pos.normal_(std=1 / np.sqrt(self.n_state))
+
+         self.patch_proj = nn.Parameter(patch_proj)
+         self.w_pos = nn.Parameter(w_pos)
+
+         self.channel_means = torch.tensor(
+             image_channel_means, dtype=torch.float32, device=self.device
+         )[None, :, None, None]
+         self.channel_stds = torch.tensor(
+             image_channel_stds, dtype=torch.float32, device=self.device
+         )[None, :, None, None]
+         self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+
+     def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if len(x.shape) != 4:
+             raise ValueError("input should be 4d")
+         if x.shape[1] != 3:
+             raise ValueError("input should have 3 channels")
+         if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size):
+             raise ValueError(f"input is not {self.image_size} x {self.image_size}")
+
+         if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None):
+             raise ValueError()
+         if self.n_timestep != 0:
+             assert t is not None
+             if len(t.shape) != 1:
+                 raise ValueError()
+             if t.shape[0] != x.shape[0]:
+                 raise ValueError()
+
+         x = (x - self.channel_means) / self.channel_stds
+         x = F.conv2d(x, self.patch_proj, stride=self.patch_size)
+         x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute(
+             0, 2, 1
+         )
+
+         sot = (
+             self.pred_state[None, None].expand(x.shape[0], -1, -1)
+             if self.n_timestep == 0
+             else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None]
+         )
+         x = torch.cat((sot, x), dim=1) + self.w_pos[None]
+         return self.ln(x)
+
+
+ @attr.s(eq=False, repr=False)
+ class AttentionResblock(nn.Module):
+     n_state: int = attr.ib()
+     n_resblocks: int = attr.ib()
+     attn_fn: AttentionInfo = attr.ib()
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.n_head_state = self.n_state // self.attn_fn.n_heads
+         self.qk_scale = 1 / np.sqrt(self.n_head_state)
+
+         self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+         self.f_q = Affine(
+             self.n_state,
+             self.n_state,
+             std=1 / math.sqrt(self.n_state),
+             use_bias=True,
+             bias_filter_fn=zero_key_bias_grad,
+             device=self.device,
+         )
+         self.f_k = Affine(
+             self.n_state,
+             self.n_state,
+             std=1 / math.sqrt(self.n_state),
+             use_bias=False,
+             bias_filter_fn=zero_key_bias_grad,
+             device=self.device,
+         )
+         self.f_v = Affine(
+             self.n_state,
+             self.n_state,
+             std=1 / math.sqrt(self.n_state),
+             use_bias=True,
+             bias_filter_fn=zero_key_bias_grad,
+             device=self.device,
+         )
+         self.f_c = Affine(
+             self.n_state,
+             self.n_state,
+             use_bias=True,
+             std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2),
+             device=self.device,
+         )  # XXX
+
+     def forward(self, m: torch.Tensor) -> torch.Tensor:
+         n_context = m.shape[1]
+         n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context
+         n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context
+         assert n_query_pad >= 0
+         assert n_key_pad >= 0
+
+         r = m
+         r = self.ln(r)
+         q, k, v = self.f_q(r), self.f_k(r), self.f_v(r)
+
+         if n_query_pad != 0:
+             q = F.pad(q, (0, 0, 0, n_query_pad))
+
+         if n_key_pad != 0:
+             k = F.pad(k, (0, 0, 0, n_key_pad))
+             v = F.pad(v, (0, 0, 0, n_key_pad))
+
+         q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
+         k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
+         v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
+         w = torch.einsum(
+             "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale)
+         )
+
+         if hasattr(self.attn_fn, "pytorch_attn_bias"):
+             bias = self.attn_fn.pytorch_attn_bias
+             assert len(bias.shape) in {2, 3}
+
+             if len(bias.shape) == 2:
+                 w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1)
+             elif len(bias.shape) == 3:
+                 w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1)
+         else:
+             w = torch.softmax(w, dim=-1)
+
+         r = torch.einsum("bhck,bhkd->bhcd", w, v)
+         r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state))
+
+         if n_query_pad != 0:
+             r = r[:, :-n_query_pad]
+
+         assert r.shape[1] == n_context
+
+         r = self.f_c(r)
+         return m + r
+
+
+ @attr.s(eq=False, repr=False)
+ class FullyConnectedResblock(nn.Module):
+     """
+     Not imported from other files because we retain Alec's original inits.
+     """
+
+     n_state: int = attr.ib()
+     n_resblocks: int = attr.ib()
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+         self.f_1 = Affine(
+             self.n_state,
+             4 * self.n_state,
+             use_bias=True,
+             std=np.sqrt(2 / (4 * self.n_state)),
+             device=self.device,
+         )
+         self.f_2 = Affine(
+             4 * self.n_state,
+             self.n_state,
+             use_bias=True,
+             std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2),
+             device=self.device,
+         )  # XXX
+
+     def forward(self, m: torch.Tensor) -> torch.Tensor:
+         r = m
+         r = self.ln(r)
+
+         r = self.f_2(F.gelu(self.f_1(r)))
+         return m + r
+
+
+ @attr.s(eq=False, repr=False)
+ class TransformerBlock(nn.Module):
+     n_state: int = attr.ib()
+     n_resblocks: int = attr.ib()
+     attn_fn: AttentionInfo = attr.ib()
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.f_attn = AttentionResblock(
+             self.n_state,
+             self.n_resblocks,
+             self.attn_fn,
+             self.device,
+         )
+         self.f_mlp = FullyConnectedResblock(self.n_state, self.n_resblocks, self.device)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.f_mlp(self.f_attn(x))
+
+
+ @attr.s(eq=False, repr=False)
+ class TextFeatureExtractor(nn.Module):
+     n_state: int = attr.ib()
+     n_embd: int = attr.ib()
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+         self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device)
+
+     def forward(
+         self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False
+     ) -> torch.Tensor:
+         if len(text.shape) != 3:
+             raise ValueError("expected text to be 3d")
+         if len(text_len.shape) != 1:
+             raise ValueError("expected text length to be 1d")
+         if text.shape[0] != text_len.shape[0]:
+             raise ValueError("text and text_len have inconsistent batch dimensions")
+
+         index = (text_len - 1)[:, None, None].expand(-1, 1, text.shape[2])
+         x = torch.gather(text, dim=1, index=index)
+         assert list(x.shape) == [text.shape[0], 1, text.shape[2]]
+
+         if return_probe_features:
+             return x[:, 0]
+
+         x = self.ln(x)
+         return self.f(x[:, 0])
+
+
+ @attr.s(eq=False, repr=False)
+ class ImageFeatureExtractor(nn.Module):
+     n_state: int = attr.ib()
+     n_embd: int = attr.ib()
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+         self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device)
+
+     def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch.Tensor:
+         if return_probe_features:
+             return x[:, 0]
+
+         x = self.ln(x[:, :1])
+         return self.f(x[:, 0])
+
+
+ @attr.s(eq=False, repr=False)
+ class TextEncoder(nn.Module):
+     n_bpe_vocab: int = attr.ib()
+     max_text_len: int = attr.ib()
+     n_embd: int = attr.ib()
+     n_head: int = attr.ib()
+     n_xf_blocks: int = attr.ib()
+     n_head_state: int = attr.ib(default=64)
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+     block_size: int = attr.ib(init=False, default=32)
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.n_state = self.n_head * self.n_head_state
+         n_rounded_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size))
+         n_pad = n_rounded_context - self.max_text_len
+
+         args = (
+             n_rounded_context,
+             n_rounded_context,
+             self.block_size,
+             self.n_head,
+             False,
+             n_pad,
+             n_pad,
+         )
+         mask = DenseCausalAttentionMask(*args)
+         attn_fn = to_attention_info(mask)
+
+         m = 1 - make_full_layout(mask).astype(np.float32)
+         m[m == 1] = -1e10
+         attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)
+
+         blocks: List[Tuple[str, nn.Module]] = [
+             (
+                 "input",
+                 TextEmbedding(
+                     self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device
+                 ),
+             )
+         ]
+
+         for i in range(self.n_xf_blocks):
+             blocks.append(
+                 (
+                     f"block_{i}",
+                     TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
+                 )
+             )
+
+         blocks.append(
+             ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device))
+         )
+
+         self.blocks = nn.ModuleDict(OrderedDict(blocks))
+
+     def forward(
+         self,
+         text: torch.Tensor,
+         text_len: torch.Tensor,
+         return_probe_features: bool = False,
+     ) -> torch.Tensor:
+
+         n_batch = text.shape[0]
+         h = self.blocks["input"](text)
+
+         for i in range(self.n_xf_blocks):
+             h = self.blocks[f"block_{i}"](h)
+
+         h = self.blocks["output"](h, text_len, return_probe_features=return_probe_features)
+
+         assert list(h.shape) == [
+             n_batch,
+             self.n_embd if not return_probe_features else self.n_state,
+         ]
+         return h
+
+
+ @attr.s(eq=False, repr=False)
+ class ImageEncoder(nn.Module):
+     image_size: int = attr.ib()
+     patch_size: int = attr.ib()
+     n_embd: int = attr.ib()
+     n_head: int = attr.ib()
+     n_xf_blocks: int = attr.ib()
+     n_head_state: int = attr.ib(default=64)
+     n_timestep: int = attr.ib(default=0)
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+     block_size: int = attr.ib(init=False, default=32)
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         self.n_state = self.n_head * self.n_head_state
+         self.n_context = 1 + (self.image_size // self.patch_size) ** 2
+         n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size))
+         n_pad = n_rounded_context - self.n_context
+
+         args = (
+             n_rounded_context,
+             n_rounded_context,
+             self.block_size,
+             self.n_head,
+             False,
+             n_pad,
+             n_pad,
+         )
+         mask = DenseAttentionMask(*args)
+         attn_fn = to_attention_info(mask)
+
+         m = 1 - make_full_layout(mask).astype(np.float32)
+         m[m == 1] = -1e10
+         attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)
+
+         blocks: List[Tuple[str, nn.Module]] = [
+             (
+                 "input",
+                 ImageEmbedding(
+                     self.image_size,
+                     self.patch_size,
+                     self.n_state,
+                     n_timestep=self.n_timestep,
+                     device=self.device,
+                 ),
+             )
+         ]
+
+         for i in range(self.n_xf_blocks):
+             blocks.append(
+                 (
+                     f"block_{i}",
+                     TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
+                 )
+             )
+
+         blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device)))
+
+         self.blocks = nn.ModuleDict(OrderedDict(blocks))
+
+     def forward(
+         self,
+         image: torch.Tensor,
+         timesteps: Optional[torch.Tensor] = None,
+         return_probe_features: bool = False,
+     ) -> torch.Tensor:
+         n_batch = image.shape[0]
+         h = self.blocks["input"](image, t=timesteps)
+
+         for i in range(self.n_xf_blocks):
+             h = self.blocks[f"block_{i}"](h)
+
+         h = self.blocks["output"](h, return_probe_features=return_probe_features)
+
+         assert list(h.shape) == [
+             n_batch,
+             self.n_embd if not return_probe_features else self.n_state,
+         ]
+
+         return h
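
A minimal sketch of wiring up the text tower on CPU with the hyperparameters from config.yaml. The weights below are freshly constructed rather than loaded from a checkpoint, so the output values are meaningless; the sketch only demonstrates the expected shapes and call signature.

import torch

from glide_text2im.clip.encoders import TextEncoder

device = torch.device("cpu")
text_encoder = TextEncoder(
    n_bpe_vocab=65536,
    max_text_len=77,
    n_embd=512,
    n_head=8,
    n_xf_blocks=12,
    n_head_state=64,
    device=device,
)

# Two dummy prompts: random token ids plus their pre-padding lengths.
tokens = torch.randint(0, 65536, (2, 77), dtype=torch.long, device=device)
lens = torch.tensor([5, 12], dtype=torch.long, device=device)

with torch.no_grad():
    z_t = text_encoder(tokens, lens)  # shape (2, 512): one embedding per prompt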
glide_text2im/clip/model_creation.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ from functools import lru_cache
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import attr
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import yaml
+ from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer
+
+ from .encoders import ImageEncoder, TextEncoder
+
+
+ @lru_cache()
+ def default_config_path() -> str:
+     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
+
+
+ @attr.s
+ class CLIPModel:
+     config: Dict[str, Any] = attr.ib()
+     text_encoder: nn.Module = attr.ib()
+     image_encoder: nn.Module = attr.ib()
+     logit_scale: torch.Tensor = attr.ib()
+     device: torch.device = attr.ib()
+     tokenizer: SimpleTokenizer = attr.ib()
+
+     def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+         tokens = []
+         lens = []
+         for prompt in prompts:
+             sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len(
+                 self.tokenizer.encode(prompt), self.text_encoder.max_text_len
+             )
+             tokens.append(sub_tokens)
+             lens.append(sub_len)
+         return (
+             torch.tensor(tokens).to(dtype=torch.long, device=self.device),
+             torch.tensor(lens).to(dtype=torch.long, device=self.device),
+         )
+
+     def text_embeddings(self, prompts: List[str]) -> torch.Tensor:
+         tokens, lens = self.encode_prompts(prompts)
+         z_t = self.text_encoder(tokens, lens)
+         return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12)
+
+     def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+         z_i = self.image_encoder((images + 1) * 127.5, t)
+         return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12)
+
+     def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
+         with torch.no_grad():
+             z_t = self.text_embeddings(prompts)
+
+         def cond_fn(x, t, grad_scale=grad_scale, **kwargs):
+             with torch.enable_grad():
+                 x_var = x.detach().requires_grad_(True)
+                 z_i = self.image_embeddings(x_var, t)
+                 loss = torch.exp(self.logit_scale) * (z_t * z_i).sum()
+                 grad = torch.autograd.grad(loss, x_var)[0].detach()
+             return grad * grad_scale
+
+         return cond_fn
+
+
+ def create_clip_model(
+     config_path: Optional[str] = None,
+     device: Optional[torch.device] = None,
+     tokenizer: Optional[SimpleTokenizer] = None,
+ ) -> CLIPModel:
+     if config_path is None:
+         config_path = default_config_path()
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     if tokenizer is None:
+         tokenizer = SimpleTokenizer()
+
+     with open(config_path, "r") as f:
+         config = yaml.load(f, Loader=yaml.SafeLoader)
+
+     text_encoder = TextEncoder(
+         n_bpe_vocab=config["n_vocab"],
+         max_text_len=config["max_text_len"],
+         n_embd=config["n_embd"],
+         n_head=config["n_head_text"],
+         n_xf_blocks=config["n_xf_blocks_text"],
+         n_head_state=config["n_head_state_text"],
+         device=device,
+     )
+
+     image_encoder = ImageEncoder(
+         image_size=config["image_size"],
+         patch_size=config["patch_size"],
+         n_embd=config["n_embd"],
+         n_head=config["n_head_image"],
+         n_xf_blocks=config["n_xf_blocks_image"],
+         n_head_state=config["n_head_state_image"],
+         n_timestep=config["n_timesteps"],
+         device=device,
+     )
+
+     logit_scale = torch.tensor(
+         np.log(config["logit_scale"]),
+         dtype=torch.float32,
+         device=device,
+         requires_grad=False,
+     )
+
+     return CLIPModel(
+         config=config,
+         text_encoder=text_encoder,
+         image_encoder=image_encoder,
+         logit_scale=logit_scale,
+         device=device,
+         tokenizer=tokenizer,
+     )
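
A minimal usage sketch, assuming the rest of the glide_text2im package (the tokenizer and a checkpoint loader) is present alongside these files; checkpoint loading is only hinted at in comments, and the sampler integration is whatever the surrounding repository provides.

import torch

from glide_text2im.clip.model_creation import create_clip_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = create_clip_model(device=device)

# In practice the two encoders are populated from pretrained checkpoints here,
# e.g. clip_model.text_encoder.load_state_dict(...) with weights obtained elsewhere.

# Gradient function for CLIP-guided diffusion: it returns
# grad_scale * d(CLIP score)/d(x_t) and is passed to the diffusion
# sampler's cond_fn argument during classifier-style guidance.
cond_fn = clip_model.cond_fn(["a corgi wearing a red bowtie"], grad_scale=3.0)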
glide_text2im/clip/utils.py ADDED
@@ -0,0 +1,97 @@
+ import math
+ from typing import Callable, Optional
+
+ import attr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ FilterFn = Callable[[torch.Tensor], torch.Tensor]
+
+
+ class ZeroKeyBiasGrad(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x):
+         return x
+
+     @staticmethod
+     def backward(ctx, output_grad):
+         output_grad = output_grad.clone()
+         output_grad.chunk(3)[1].zero_()
+         return output_grad
+
+
+ def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
+     return ZeroKeyBiasGrad.apply(x)
+
+
+ @attr.s(eq=False, repr=False)
+ class LayerNorm(nn.Module):
+     n_state: int = attr.ib()
+     eps: float = attr.ib(default=1e-6)
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+         self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device))
+         self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device))
+         self.g.weight_decay_level = "disable"  # type: ignore
+         self.b.weight_decay_level = "disable"  # type: ignore
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.layer_norm(
+             x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps
+         )
+
+
+ @attr.s(eq=False, repr=False)
+ class Affine(nn.Module):
+     n_in: int = attr.ib()
+     n_out: int = attr.ib()
+     use_bias: bool = attr.ib(default=True)
+     use_admnet_init: bool = attr.ib(default=False)
+     std: Optional[float] = attr.ib(default=None)
+     extra_init_scale: Optional[float] = attr.ib(default=None)
+     bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
+     device: torch.device = attr.ib(default=torch.device("cuda"))
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         if not self.use_admnet_init:
+             self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out))
+             self.std = (
+                 self.std if self.extra_init_scale is None else self.std * self.extra_init_scale
+             )
+
+             w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
+             self.w = nn.Parameter(w)
+
+             if self.use_bias:
+                 self.b = nn.Parameter(
+                     torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)
+                 )
+                 self.b.weight_decay_level = "disable"  # type: ignore
+         else:
+             if self.extra_init_scale is not None:
+                 raise ValueError("extra_init_scale incompatible with admnet init")
+
+             w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
+
+             if self.use_bias:
+                 b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device)
+
+             self.w = nn.Parameter(w)
+
+             if self.use_bias:
+                 self.b = nn.Parameter(b)
+                 self.b.weight_decay_level = "disable"  # type: ignore
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype)
+         b = (
+             self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype))
+             if self.use_bias
+             else None
+         )
+         return F.linear(x, w, b)
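
A small sketch of what zero_key_bias_grad does: the forward pass is the identity, while the backward pass zeroes the gradient of the middle third of the tensor (chunk(3)[1]); the name suggests it was written for a fused query/key/value bias, where that middle chunk is the key bias.

import torch

from glide_text2im.clip.utils import zero_key_bias_grad

b = torch.ones(9, requires_grad=True)
zero_key_bias_grad(b).sum().backward()

# Forward is the identity; on backward the middle chunk's gradient is zeroed.
print(b.grad)  # tensor([1., 1., 1., 0., 0., 0., 1., 1., 1.])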