leeyunjai commited on
Commit
e973758
·
1 Parent(s): 7263355

Create transformer.py

Browse files
Files changed (1) hide show
  1. transformer.py +339 -0
transformer.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import copy
3
+ from typing import Optional, List
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn, Tensor
8
+
9
+
10
+ class Transformer(nn.Module):
11
+
12
+ def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6,
13
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
14
+ activation="relu", normalize_before=False,
15
+ return_intermediate_dec=False):
16
+ super().__init__()
17
+
18
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
19
+ dropout, activation, normalize_before)
20
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
21
+ self.encoder = TransformerEncoder(
22
+ encoder_layer, num_encoder_layers, encoder_norm)
23
+
24
+ self.embeddings = DecoderEmbeddings(config)
25
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
26
+ dropout, activation, normalize_before)
27
+ decoder_norm = nn.LayerNorm(d_model)
28
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
29
+ return_intermediate=return_intermediate_dec)
30
+
31
+ self._reset_parameters()
32
+
33
+ self.d_model = d_model
34
+ self.nhead = nhead
35
+
36
+ def _reset_parameters(self):
37
+ for p in self.parameters():
38
+ if p.dim() > 1:
39
+ nn.init.xavier_uniform_(p)
40
+
41
+ def forward(self, src, mask, pos_embed, tgt, tgt_mask):
42
+ # flatten NxCxHxW to HWxNxC
43
+ bs, c, h, w = src.shape
44
+ src = src.flatten(2).permute(2, 0, 1)
45
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
46
+ mask = mask.flatten(1)
47
+
48
+ tgt = self.embeddings(tgt).permute(1, 0, 2)
49
+ query_embed = self.embeddings.position_embeddings.weight.unsqueeze(1)
50
+ query_embed = query_embed.repeat(1, bs, 1)
51
+
52
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
53
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, tgt_key_padding_mask=tgt_mask,
54
+ pos=pos_embed, query_pos=query_embed,
55
+ tgt_mask=generate_square_subsequent_mask(len(tgt)).to(tgt.device))
56
+
57
+ return hs
58
+
59
+
60
+ class TransformerEncoder(nn.Module):
61
+
62
+ def __init__(self, encoder_layer, num_layers, norm=None):
63
+ super().__init__()
64
+ self.layers = _get_clones(encoder_layer, num_layers)
65
+ self.num_layers = num_layers
66
+ self.norm = norm
67
+
68
+ def forward(self, src,
69
+ mask: Optional[Tensor] = None,
70
+ src_key_padding_mask: Optional[Tensor] = None,
71
+ pos: Optional[Tensor] = None):
72
+ output = src
73
+
74
+ for layer in self.layers:
75
+ output = layer(output, src_mask=mask,
76
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
77
+
78
+ if self.norm is not None:
79
+ output = self.norm(output)
80
+
81
+ return output
82
+
83
+
84
+ class TransformerDecoder(nn.Module):
85
+
86
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
87
+ super().__init__()
88
+ self.layers = _get_clones(decoder_layer, num_layers)
89
+ self.num_layers = num_layers
90
+ self.norm = norm
91
+ self.return_intermediate = return_intermediate
92
+
93
+ def forward(self, tgt, memory,
94
+ tgt_mask: Optional[Tensor] = None,
95
+ memory_mask: Optional[Tensor] = None,
96
+ tgt_key_padding_mask: Optional[Tensor] = None,
97
+ memory_key_padding_mask: Optional[Tensor] = None,
98
+ pos: Optional[Tensor] = None,
99
+ query_pos: Optional[Tensor] = None):
100
+ output = tgt
101
+
102
+ intermediate = []
103
+
104
+ for layer in self.layers:
105
+ output = layer(output, memory, tgt_mask=tgt_mask,
106
+ memory_mask=memory_mask,
107
+ tgt_key_padding_mask=tgt_key_padding_mask,
108
+ memory_key_padding_mask=memory_key_padding_mask,
109
+ pos=pos, query_pos=query_pos)
110
+ if self.return_intermediate:
111
+ intermediate.append(self.norm(output))
112
+
113
+ if self.norm is not None:
114
+ output = self.norm(output)
115
+ if self.return_intermediate:
116
+ intermediate.pop()
117
+ intermediate.append(output)
118
+
119
+ if self.return_intermediate:
120
+ return torch.stack(intermediate)
121
+
122
+ return output
123
+
124
+
125
+ class TransformerEncoderLayer(nn.Module):
126
+
127
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
128
+ activation="relu", normalize_before=False):
129
+ super().__init__()
130
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
131
+ # Implementation of Feedforward model
132
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
133
+ self.dropout = nn.Dropout(dropout)
134
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
135
+
136
+ self.norm1 = nn.LayerNorm(d_model)
137
+ self.norm2 = nn.LayerNorm(d_model)
138
+ self.dropout1 = nn.Dropout(dropout)
139
+ self.dropout2 = nn.Dropout(dropout)
140
+
141
+ self.activation = _get_activation_fn(activation)
142
+ self.normalize_before = normalize_before
143
+
144
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
145
+ return tensor if pos is None else tensor + pos
146
+
147
+ def forward_post(self,
148
+ src,
149
+ src_mask: Optional[Tensor] = None,
150
+ src_key_padding_mask: Optional[Tensor] = None,
151
+ pos: Optional[Tensor] = None):
152
+ q = k = self.with_pos_embed(src, pos)
153
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
154
+ key_padding_mask=src_key_padding_mask)[0]
155
+ src = src + self.dropout1(src2)
156
+ src = self.norm1(src)
157
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
158
+ src = src + self.dropout2(src2)
159
+ src = self.norm2(src)
160
+ return src
161
+
162
+ def forward_pre(self, src,
163
+ src_mask: Optional[Tensor] = None,
164
+ src_key_padding_mask: Optional[Tensor] = None,
165
+ pos: Optional[Tensor] = None):
166
+ src2 = self.norm1(src)
167
+ q = k = self.with_pos_embed(src2, pos)
168
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
169
+ key_padding_mask=src_key_padding_mask)[0]
170
+ src = src + self.dropout1(src2)
171
+ src2 = self.norm2(src)
172
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
173
+ src = src + self.dropout2(src2)
174
+ return src
175
+
176
+ def forward(self, src,
177
+ src_mask: Optional[Tensor] = None,
178
+ src_key_padding_mask: Optional[Tensor] = None,
179
+ pos: Optional[Tensor] = None):
180
+ if self.normalize_before:
181
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
182
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
183
+
184
+
185
+ class TransformerDecoderLayer(nn.Module):
186
+
187
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
188
+ activation="relu", normalize_before=False):
189
+ super().__init__()
190
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
191
+ self.multihead_attn = nn.MultiheadAttention(
192
+ d_model, nhead, dropout=dropout)
193
+ # Implementation of Feedforward model
194
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
195
+ self.dropout = nn.Dropout(dropout)
196
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
197
+
198
+ self.norm1 = nn.LayerNorm(d_model)
199
+ self.norm2 = nn.LayerNorm(d_model)
200
+ self.norm3 = nn.LayerNorm(d_model)
201
+ self.dropout1 = nn.Dropout(dropout)
202
+ self.dropout2 = nn.Dropout(dropout)
203
+ self.dropout3 = nn.Dropout(dropout)
204
+
205
+ self.activation = _get_activation_fn(activation)
206
+ self.normalize_before = normalize_before
207
+
208
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
209
+ return tensor if pos is None else tensor + pos
210
+
211
+ def forward_post(self, tgt, memory,
212
+ tgt_mask: Optional[Tensor] = None,
213
+ memory_mask: Optional[Tensor] = None,
214
+ tgt_key_padding_mask: Optional[Tensor] = None,
215
+ memory_key_padding_mask: Optional[Tensor] = None,
216
+ pos: Optional[Tensor] = None,
217
+ query_pos: Optional[Tensor] = None):
218
+ q = k = self.with_pos_embed(tgt, query_pos)
219
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
220
+ key_padding_mask=tgt_key_padding_mask)[0]
221
+ tgt = tgt + self.dropout1(tgt2)
222
+ tgt = self.norm1(tgt)
223
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
224
+ key=self.with_pos_embed(memory, pos),
225
+ value=memory, attn_mask=memory_mask,
226
+ key_padding_mask=memory_key_padding_mask)[0]
227
+ tgt = tgt + self.dropout2(tgt2)
228
+ tgt = self.norm2(tgt)
229
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
230
+ tgt = tgt + self.dropout3(tgt2)
231
+ tgt = self.norm3(tgt)
232
+ return tgt
233
+
234
+ def forward_pre(self, tgt, memory,
235
+ tgt_mask: Optional[Tensor] = None,
236
+ memory_mask: Optional[Tensor] = None,
237
+ tgt_key_padding_mask: Optional[Tensor] = None,
238
+ memory_key_padding_mask: Optional[Tensor] = None,
239
+ pos: Optional[Tensor] = None,
240
+ query_pos: Optional[Tensor] = None):
241
+ tgt2 = self.norm1(tgt)
242
+ q = k = self.with_pos_embed(tgt2, query_pos)
243
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
244
+ key_padding_mask=tgt_key_padding_mask)[0]
245
+ tgt = tgt + self.dropout1(tgt2)
246
+ tgt2 = self.norm2(tgt)
247
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
248
+ key=self.with_pos_embed(memory, pos),
249
+ value=memory, attn_mask=memory_mask,
250
+ key_padding_mask=memory_key_padding_mask)[0]
251
+ tgt = tgt + self.dropout2(tgt2)
252
+ tgt2 = self.norm3(tgt)
253
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
254
+ tgt = tgt + self.dropout3(tgt2)
255
+ return tgt
256
+
257
+ def forward(self, tgt, memory,
258
+ tgt_mask: Optional[Tensor] = None,
259
+ memory_mask: Optional[Tensor] = None,
260
+ tgt_key_padding_mask: Optional[Tensor] = None,
261
+ memory_key_padding_mask: Optional[Tensor] = None,
262
+ pos: Optional[Tensor] = None,
263
+ query_pos: Optional[Tensor] = None):
264
+ if self.normalize_before:
265
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
266
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
267
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
268
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
269
+
270
+
271
+ class DecoderEmbeddings(nn.Module):
272
+ def __init__(self, config):
273
+ super().__init__()
274
+ self.word_embeddings = nn.Embedding(
275
+ config.vocab_size, config.hidden_dim, padding_idx=config.pad_token_id)
276
+ self.position_embeddings = nn.Embedding(
277
+ config.max_position_embeddings, config.hidden_dim
278
+ )
279
+
280
+ self.LayerNorm = torch.nn.LayerNorm(
281
+ config.hidden_dim, eps=config.layer_norm_eps)
282
+ self.dropout = nn.Dropout(config.dropout)
283
+
284
+ def forward(self, x):
285
+ input_shape = x.size()
286
+ seq_length = input_shape[1]
287
+ device = x.device
288
+
289
+ position_ids = torch.arange(
290
+ seq_length, dtype=torch.long, device=device)
291
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
292
+
293
+ input_embeds = self.word_embeddings(x)
294
+ position_embeds = self.position_embeddings(position_ids)
295
+
296
+ embeddings = input_embeds + position_embeds
297
+ embeddings = self.LayerNorm(embeddings)
298
+ embeddings = self.dropout(embeddings)
299
+
300
+ return embeddings
301
+
302
+
303
+ def _get_clones(module, N):
304
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
305
+
306
+
307
+ def _get_activation_fn(activation):
308
+ """Return an activation function given a string"""
309
+ if activation == "relu":
310
+ return F.relu
311
+ if activation == "gelu":
312
+ return F.gelu
313
+ if activation == "glu":
314
+ return F.glu
315
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
316
+
317
+
318
+ def generate_square_subsequent_mask(sz):
319
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
320
+ Unmasked positions are filled with float(0.0).
321
+ """
322
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
323
+ mask = mask.float().masked_fill(mask == 0, float(
324
+ '-inf')).masked_fill(mask == 1, float(0.0))
325
+ return mask
326
+
327
+
328
+ def build_transformer(config):
329
+ return Transformer(
330
+ config,
331
+ d_model=config.hidden_dim,
332
+ dropout=config.dropout,
333
+ nhead=config.nheads,
334
+ dim_feedforward=config.dim_feedforward,
335
+ num_encoder_layers=config.enc_layers,
336
+ num_decoder_layers=config.dec_layers,
337
+ normalize_before=config.pre_norm,
338
+ return_intermediate_dec=False,
339
+ )