hugohrban committed
Commit 3528543 · verified · 1 parent: bc13d43

Update modeling_progen.py

Files changed (1)
  1. modeling_progen.py (+35 -145)
modeling_progen.py CHANGED
@@ -32,7 +32,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_progen import ProGenConfig
 
 
@@ -52,12 +51,11 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
 
 
-def rotate_every_two(x):
+def rotate_every_two(x: torch.Tensor):
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
     x = torch.stack((-x2, x1), axis=-1)
-    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
-
+    return x.flatten(-2)
 
 def apply_rotary_pos_emb(x, sincos, offset=0):
     sin, cos = map(
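Note: rotate_every_two is the pairwise rotation used by the rotary position embedding; it maps (x0, x1, x2, x3, ...) to (-x1, x0, -x3, x2, ...) along the last dimension. A minimal standalone sketch of that behaviour (illustrative only, not part of this commit; the toy tensor shape is made up):

import torch

def rotate_every_two(x):
    # pairs (x0, x1) along the last dim become (-x1, x0)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

x = torch.arange(8.0).view(1, 1, 1, 8)  # toy (batch, heads, seq, dim) tensor
print(rotate_every_two(x))  # tensor([[[[-1., 0., -3., 2., -5., 4., -7., 6.]]]])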
@@ -74,20 +72,21 @@ class ProGenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
 
-        max_positions = config.max_position_embeddings
+        max_positions = config.n_positions
         self.register_buffer(
             "bias",
             torch.tril(
                 torch.ones((max_positions, max_positions), dtype=torch.bool)
             ).view(1, 1, max_positions, max_positions),
+            persistent=False
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)  # approx. -inf
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
-        self.embed_dim = config.hidden_size
-        self.num_attention_heads = config.num_attention_heads
+        self.embed_dim = config.embed_dim
+        self.num_attention_heads = config.n_head
         self.head_dim = self.embed_dim // self.num_attention_heads
         if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
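Note: persistent=False keeps the causal "bias" mask and "masked_bias" buffers out of the checkpoint's state_dict; they are rebuilt in __init__ rather than loaded from disk. A small sketch of that behaviour (class name and max_positions are made up for illustration):

import torch
import torch.nn as nn

class CausalMaskDemo(nn.Module):
    def __init__(self, max_positions=8):
        super().__init__()
        mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
        self.register_buffer("bias", mask.view(1, 1, max_positions, max_positions), persistent=False)
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)

m = CausalMaskDemo()
print(list(m.state_dict().keys()))  # [] -- non-persistent buffers are not serialized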
@@ -103,12 +102,12 @@ class ProGenAttention(nn.Module):
         if config.rotary_dim is not None:
             self.rotary_dim = config.rotary_dim
 
-    def _split_heads(self, x, n_head, dim_head, mp_num):
-        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
-        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
-        return reshaped
+    def _split_heads(self, x: torch.Tensor, n_head, dim_head) -> torch.Tensor:
+        x = x.reshape(x.shape[:-2] + (-1,))  # (B, T, 8 * E // 8)
+        x = x.reshape(x.shape[:-1] + (n_head, dim_head))  # (B, T, n_heads, dim_head)
+        return x
 
-    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size) -> torch.Tensor:
         """
         Merges attn_head_size dim and num_attn_heads dim into n_positions
         """
@@ -140,17 +139,17 @@ class ProGenAttention(nn.Module):
         # Keep the attention weights computation in fp32 to avoid overflow issues
         query = query.to(torch.float32)
         key = key.to(torch.float32)
-        #print("q.shape", query.shape)
-        #print("k.shape", key.shape)
-        attn_weights = query @ key.transpose(-1, -2)
+
+        attn_weights = query @ key.transpose(-1, -2)  # (B, n_heads, T, T)
 
         attn_weights = attn_weights / self.scale_attn
+
+        # attend only to previous positions
         attn_weights = torch.where(
             causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
         )
 
         if attention_mask is not None:
-            # Apply the attention mask
             attn_weights = attn_weights + attention_mask
 
         attn_weights = F.softmax(attn_weights, dim=-1)
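Note: the block above is standard scaled dot-product attention with a causal mask applied before the softmax. A condensed, self-contained sketch (head_dim ** 0.5 stands in for self.scale_attn, which is an assumption about how the scale is defined):

import torch
import torch.nn.functional as F

B, H, T, d = 1, 2, 4, 8
query, key, value = (torch.randn(B, H, T, d) for _ in range(3))

causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)
masked_bias = torch.tensor(-1e9)

attn_weights = query.float() @ key.float().transpose(-1, -2)         # (B, H, T, T)
attn_weights = attn_weights / (d ** 0.5)
attn_weights = torch.where(causal_mask, attn_weights, masked_bias)   # attend only to previous positions
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = attn_weights @ value.float()                           # (B, H, T, d)
print(attn_output.shape)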
@@ -160,7 +159,7 @@ class ProGenAttention(nn.Module):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        attn_output = attn_weights @ value
+        attn_output = attn_weights @ value  # (B, n_heads, T, dim_head)
 
         return attn_output, attn_weights
 
@@ -173,24 +172,16 @@ class ProGenAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
-        qkv = self.qkv_proj(hidden_states)
-        # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
-        # mp_num = 4
+        qkv = self.qkv_proj(hidden_states)  # (B, T, 3 * E)
+
         mp_num = 8
-        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
+        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))  # (B, T, 8, 3 * E // 8)
 
-        local_dim = self.head_dim * self.num_attention_heads // mp_num
-        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
-        query = self._split_heads(
-            query, self.num_attention_heads, self.head_dim, mp_num=mp_num
-        )
-        key = self._split_heads(
-            key, self.num_attention_heads, self.head_dim, mp_num=mp_num
-        )
+        query, value, key = torch.split(qkv_split, self.embed_dim // mp_num, dim=-1)  # 3 * (B, T, 8, E // 8)
 
-        value = self._split_heads(
-            value, self.num_attention_heads, self.head_dim, mp_num=mp_num
-        )
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim)  # (B, T, n_heads, dim_head)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim)  # (B, T, n_heads, dim_head)
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim)  # (B, T, n_heads, dim_head)
         value = value.permute(0, 2, 1, 3)
 
         seq_len = key.shape[1]
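Note: the fused qkv output is still reshaped into mp_num = 8 groups before splitting, because the checkpoint stores query, value and key interleaved in that packed layout; only the split size changes from local_dim to self.embed_dim // mp_num, which is the same quantity. A shape sketch with made-up dimensions:

import torch
import torch.nn as nn

B, T, embed_dim, mp_num = 2, 5, 1024, 8
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

qkv = qkv_proj(torch.randn(B, T, embed_dim))              # (B, T, 3 * E)
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))    # (B, T, 8, 3 * E // 8)
query, value, key = torch.split(qkv_split, embed_dim // mp_num, dim=-1)
print(query.shape, value.shape, key.shape)                # each: (B, T, 8, E // 8)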
@@ -237,7 +228,7 @@ class ProGenAttention(nn.Module):
             query, key, value, attention_mask, head_mask
         )
 
-        attn_output = self._merge_heads(
+        attn_output = self._merge_heads(  # (B, T, E)
             attn_output, self.num_attention_heads, self.head_dim
         )
 
@@ -256,7 +247,7 @@ class ProGenMLP(nn.Module):
         self, intermediate_size, config
     ):  # in MLP: intermediate_size= 4 * embed_dim
         super().__init__()
-        embed_dim = config.n_embd
+        embed_dim = config.embed_dim
 
         self.fc_in = nn.Linear(embed_dim, intermediate_size)
         self.fc_out = nn.Linear(intermediate_size, embed_dim)
@@ -275,8 +266,8 @@ class ProGenMLP(nn.Module):
 class ProGenBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
-        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.embed_dim
+        self.ln_1 = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_epsilon)
         self.attn = ProGenAttention(config)
         self.mlp = ProGenMLP(inner_dim, config)
 
@@ -302,7 +293,7 @@ class ProGenBlock(nn.Module):
         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
         outputs = attn_outputs[1:]
 
-        feed_forward_hidden_states = self.mlp(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)  # (B, T, E)
         hidden_states = attn_output + feed_forward_hidden_states + residual
 
         if use_cache:
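Note: the block keeps the GPT-J-style parallel residual: attention and the MLP both read the same LayerNorm output and their results are summed with the residual, rather than being applied one after the other. A toy sketch with placeholder sub-modules (the Linear/Sequential stand-ins are not the real ProGenAttention/ProGenMLP):

import torch
import torch.nn as nn

class ToyParallelBlock(nn.Module):
    def __init__(self, embed_dim=64):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.Linear(embed_dim, embed_dim)  # placeholder for self-attention
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim)
        )

    def forward(self, x):
        residual = x
        hidden = self.ln_1(x)
        return self.attn(hidden) + self.mlp(hidden) + residual  # parallel, not sequential

print(ToyParallelBlock()(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])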
@@ -321,7 +312,7 @@ class ProGenPreTrainedModel(PreTrainedModel):
 
     config_class = ProGenConfig
     base_model_prefix = "transformer"
-    is_parallelizable = True
+    is_parallelizable = False
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -347,61 +338,16 @@ class ProGenModel(ProGenPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.vocab_size_emb = config.vocab_size_emb
-        self.embed_dim = config.n_embd
+        self.embed_dim = config.embed_dim
         self.wte = nn.Embedding(config.vocab_size_emb, self.embed_dim)
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
         self.rotary_dim = min(
-            config.rotary_dim, config.n_positions // config.num_attention_heads
+            config.rotary_dim, config.n_positions // config.n_head
         )
         self.init_weights()
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-    def parallelize(self, device_map=None):
-        # Check validity of device_map
-        self.device_map = (
-            get_device_map(len(self.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.h))
-        self.model_parallel = True
-        self.first_device = (
-            "cpu"
-            if "cpu" in self.device_map.keys()
-            else "cuda:" + str(min(self.device_map.keys()))
-        )
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
-        self.wte = self.wte.to(self.first_device)
-        # Load onto devices
-        for k, v in self.device_map.items():
-            for block in v:
-                cuda_device = "cuda:" + str(k)
-                self.h[block] = self.h[block].to(cuda_device)
-        # ln_f to last
-        self.ln_f = self.ln_f.to(self.last_device)
-
-    def deparallelize(self):
-        self.model_parallel = False
-        self.device_map = None
-        self.first_device = "cpu"
-        self.last_device = "cpu"
-        self.wte = self.wte.to("cpu")
-        for index in range(len(self.h)):
-            self.h[index] = self.h[index].to("cpu")
-        self.ln_f = self.ln_f.to("cpu")
-        torch.cuda.empty_cache()
-
-    def get_input_embeddings(self):
-        return self.wte
-
-    def set_input_embeddings(self, new_embeddings):
-        self.wte = new_embeddings
-
     def forward(
         self,
         input_ids=None,
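Note: with parallelize()/deparallelize() removed and is_parallelizable set to False, the old layer-by-layer model parallelism is no longer exposed by these classes. If multi-device placement is still needed, the accelerate-backed device_map argument of from_pretrained is the usual alternative; a hedged usage sketch (the checkpoint name is illustrative, and trust_remote_code is assumed to be required because this is custom modeling code):

from transformers import AutoModelForCausalLM

# hypothetical checkpoint name; device_map="auto" requires the accelerate package
model = AutoModelForCausalLM.from_pretrained(
    "hugohrban/progen2-small",
    trust_remote_code=True,
    device_map="auto",
)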
@@ -510,19 +456,6 @@ class ProGenModel(ProGenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(
-                        past_state.to(hidden_states.device) for past_state in layer_past
-                    )
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -567,12 +500,6 @@ class ProGenModel(ProGenPreTrainedModel):
                     outputs[2 if use_cache else 1],
                 )
 
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
         hidden_states = self.ln_f(hidden_states)
 
         hidden_states = hidden_states.view(*output_shape)
@@ -591,7 +518,7 @@ class ProGenModel(ProGenPreTrainedModel):
                 ]
                 if v is not None
             )
-        # print("hidden_states", hidden_states.shape)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=presents,
@@ -610,37 +537,9 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = ProGenModel(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size_lm_head)
+        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size_lm_head)
         self.init_weights()
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-    # def parallelize(self, device_map=None):
-    #     self.device_map = (
-    #         get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-    #         if device_map is None
-    #         else device_map
-    #     )
-    #     assert_device_map(self.device_map, len(self.transformer.h))
-    #     self.transformer.parallelize(self.device_map)
-    #     self.lm_head = self.lm_head.to(self.transformer.first_device)
-    #     self.model_parallel = True
-
-    # def deparallelize(self):
-    #     self.transformer.deparallelize()
-    #     self.transformer = self.transformer.to("cpu")
-    #     self.lm_head = self.lm_head.to("cpu")
-    #     self.model_parallel = False
-    #     torch.cuda.empty_cache()
-
-    # def get_output_embeddings(self):
-    #     return None
-
-    # def set_output_embeddings(self, new_embeddings):
-    #     return
-
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
@@ -650,7 +549,6 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
             token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
 
         attention_mask = kwargs.get("attention_mask", None)
-        # print("attention_mask", attention_mask)
         position_ids = kwargs.get("position_ids", None)
 
         if attention_mask is not None and position_ids is None:
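Note: right after this context line, position_ids are derived from the attention mask when the caller did not supply them. That body falls outside the hunk, so the snippet below shows the common GPT-style derivation rather than this file's exact code (treat it as an assumption): positions are the cumulative sum of the mask minus one, with padded slots given a dummy value.

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])        # toy left-padded mask
position_ids = attention_mask.long().cumsum(-1) - 1     # [[-1, -1, 0, 1, 2]]
position_ids.masked_fill_(attention_mask == 0, 1)       # padded slots get a dummy position
print(position_ids)                                     # tensor([[1, 1, 0, 1, 2]])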
@@ -694,8 +592,7 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-        # print("here")
-        # print(attention_mask)
+
         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
@@ -711,11 +608,6 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
 
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
-
         # make sure sampling in fp16 works correctly and
         # compute loss in fp32 to match with mesh-tf version
         # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
@@ -726,12 +618,10 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(
                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
             )
-
             loss = loss.to(hidden_states.dtype)
 
         if not return_dict:
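Note: the loss block above is the standard causal-LM objective: the logits at position t are scored against the token at position t + 1, so the last logit and the first label are dropped before the flattened cross-entropy. A small self-contained check with random tensors:

import torch
from torch.nn import CrossEntropyLoss

B, T, V = 2, 6, 32
lm_logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))

shift_logits = lm_logits[..., :-1, :].contiguous()   # (B, T - 1, V)
shift_labels = labels[..., 1:].contiguous()          # (B, T - 1)

loss = CrossEntropyLoss()(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
print(loss)  # scalar tensor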
 