Update modeling_progen.py
modeling_progen.py (+35 −145) CHANGED
@@ -32,7 +32,6 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_progen import ProGenConfig
 
 
@@ -52,12 +51,11 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
 
 
-def rotate_every_two(x):
+def rotate_every_two(x: torch.Tensor):
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
     x = torch.stack((-x2, x1), axis=-1)
-    return x.flatten(-2)
-
+    return x.flatten(-2)
 
 def apply_rotary_pos_emb(x, sincos, offset=0):
     sin, cos = map(
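For readers skimming the diff: the standalone sketch below (not part of this commit; it spells the stack argument as dim=-1, which is equivalent to axis=-1 above) shows what rotate_every_two does to the last dimension.

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, :, :, ::2]                 # even-indexed features
    x2 = x[:, :, :, 1::2]                # odd-indexed features
    x = torch.stack((-x2, x1), dim=-1)   # pair each (even, odd) as (-odd, even)
    return x.flatten(-2)                 # interleave the pairs back

x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])    # shape (1, 1, 1, 4)
print(rotate_every_two(x))                      # tensor([[[[-2., 1., -4., 3.]]]])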
@@ -74,20 +72,21 @@ class ProGenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
 
-        max_positions = config.
+        max_positions = config.n_positions
         self.register_buffer(
             "bias",
             torch.tril(
                 torch.ones((max_positions, max_positions), dtype=torch.bool)
             ).view(1, 1, max_positions, max_positions),
+            persistent=False
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
+        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)  # approx. -inf
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
-        self.embed_dim = config.
-        self.num_attention_heads = config.
+        self.embed_dim = config.embed_dim
+        self.num_attention_heads = config.n_head
         self.head_dim = self.embed_dim // self.num_attention_heads
         if self.head_dim * self.num_attention_heads != self.embed_dim:
             raise ValueError(
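A standalone illustration of the "bias" buffer registered above and of persistent=False (the mask is cheap to rebuild from the config, so it is excluded from the checkpoint's state_dict). The (query_len, key_len) slicing shown here mirrors the GPT-J-style attention code this commit leaves unchanged, so treat that part as an assumption about the rest of the file.

import torch

max_positions = 5
bias = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

query_len, key_len = 3, 3
causal_mask = bias[:, :, key_len - query_len : key_len, :key_len]
print(causal_mask[0, 0])
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])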
@@ -103,12 +102,12 @@ class ProGenAttention(nn.Module):
         if config.rotary_dim is not None:
             self.rotary_dim = config.rotary_dim
 
-    def _split_heads(self, x, n_head, dim_head
-
-
-        return
+    def _split_heads(self, x: torch.Tensor, n_head, dim_head) -> torch.Tensor:
+        x = x.reshape(x.shape[:-2] + (-1,))  # (B, T, 8 * E // 8)
+        x = x.reshape(x.shape[:-1] + (n_head, dim_head))  # (B, T, n_heads, dim_head)
+        return x
 
-    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size) -> torch.Tensor:
         """
         Merges attn_head_size dim and num_attn_heads dim into n_positions
         """
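Shape check for the rewritten _split_heads (toy sizes, chosen only for illustration): it collapses the mp_num chunks back into one embedding axis and then splits that axis into heads.

import torch

B, T, E, n_head, mp_num = 2, 5, 16, 4, 8
dim_head = E // n_head

x = torch.randn(B, T, mp_num, E // mp_num)        # one of the q/k/v chunks
x = x.reshape(x.shape[:-2] + (-1,))               # (B, T, E)
x = x.reshape(x.shape[:-1] + (n_head, dim_head))  # (B, T, n_head, dim_head)
print(x.shape)                                    # torch.Size([2, 5, 4, 4])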
@@ -140,17 +139,17 @@ class ProGenAttention(nn.Module):
         # Keep the attention weights computation in fp32 to avoid overflow issues
         query = query.to(torch.float32)
         key = key.to(torch.float32)
-
-        #
-        attn_weights = query @ key.transpose(-1, -2)
+
+        attn_weights = query @ key.transpose(-1, -2)  # (B, n_heads, T, T)
 
         attn_weights = attn_weights / self.scale_attn
+
+        # attend only to previous positions
         attn_weights = torch.where(
             causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
         )
 
         if attention_mask is not None:
-            # Apply the attention mask
             attn_weights = attn_weights + attention_mask
 
         attn_weights = F.softmax(attn_weights, dim=-1)
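The masking logic above in miniature (standalone sketch; the sqrt(head_dim) scale stands in for self.scale_attn and is an assumption): scores outside the causal window are overwritten with a large negative value, so softmax assigns them essentially zero weight.

import torch
import torch.nn.functional as F

B, n_heads, T, dim_head = 1, 1, 4, 8
query = torch.randn(B, n_heads, T, dim_head)
key = torch.randn(B, n_heads, T, dim_head)

causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)
masked_bias = torch.tensor(-1e9)

attn_weights = (query.float() @ key.float().transpose(-1, -2)) / (dim_head ** 0.5)
attn_weights = torch.where(causal_mask, attn_weights, masked_bias)
attn_weights = F.softmax(attn_weights, dim=-1)
print(attn_weights[0, 0])   # strictly upper-triangular entries are ~0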
@@ -160,7 +159,7 @@ class ProGenAttention(nn.Module):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        attn_output = attn_weights @ value
+        attn_output = attn_weights @ value  # (B, n_heads, T, dim_head)
 
         return attn_output, attn_weights
 
@@ -173,24 +172,16 @@ class ProGenAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
-        qkv = self.qkv_proj(hidden_states)
-
-        # mp_num = 4
+        qkv = self.qkv_proj(hidden_states)  # (B, T, 3 * E)
+
         mp_num = 8
-        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
+        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))  # (B, T, 8, 3 * E // 8)
 
-
-        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
-        query = self._split_heads(
-            query, self.num_attention_heads, self.head_dim, mp_num=mp_num
-        )
-        key = self._split_heads(
-            key, self.num_attention_heads, self.head_dim, mp_num=mp_num
-        )
+        query, value, key = torch.split(qkv_split, self.embed_dim // mp_num, dim=-1)  # 3 * (B, T, 8, E // 8)
 
-
-
-        )
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim)  # (B, T, n_heads, dim_head)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim)  # (B, T, n_heads, dim_head)
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim)  # (B, T, n_heads, dim_head)
         value = value.permute(0, 2, 1, 3)
 
         seq_len = key.shape[1]
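Why self.embed_dim // mp_num is the right split size (toy numbers, not from any real checkpoint): qkv_proj emits 3 * E features, so after reshaping into mp_num chunks each chunk holds 3 * E // mp_num features, which torch.split carves into equal thirds for query, value and key.

import torch

B, T, E, mp_num = 2, 5, 16, 8
qkv = torch.randn(B, T, 3 * E)                            # stand-in for qkv_proj output
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))    # (B, T, 8, 3 * E // 8)
query, value, key = torch.split(qkv_split, E // mp_num, dim=-1)
print(query.shape, value.shape, key.shape)                # three times (2, 5, 8, 2)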
@@ -237,7 +228,7 @@ class ProGenAttention(nn.Module):
             query, key, value, attention_mask, head_mask
         )
 
-        attn_output = self._merge_heads(
+        attn_output = self._merge_heads(  # (B, T, E)
             attn_output, self.num_attention_heads, self.head_dim
         )
 
@@ -256,7 +247,7 @@ class ProGenMLP(nn.Module):
         self, intermediate_size, config
     ):  # in MLP: intermediate_size= 4 * embed_dim
         super().__init__()
-        embed_dim = config.
+        embed_dim = config.embed_dim
 
         self.fc_in = nn.Linear(embed_dim, intermediate_size)
         self.fc_out = nn.Linear(intermediate_size, embed_dim)
@@ -275,8 +266,8 @@ class ProGenMLP(nn.Module):
 class ProGenBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.
-        self.ln_1 = nn.LayerNorm(config.
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.embed_dim
+        self.ln_1 = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_epsilon)
         self.attn = ProGenAttention(config)
         self.mlp = ProGenMLP(inner_dim, config)
 
@@ -302,7 +293,7 @@ class ProGenBlock(nn.Module):
         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
         outputs = attn_outputs[1:]
 
-        feed_forward_hidden_states = self.mlp(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)  # (B, T, E)
         hidden_states = attn_output + feed_forward_hidden_states + residual
 
         if use_cache:
@@ -321,7 +312,7 @@ class ProGenPreTrainedModel(PreTrainedModel):
 
     config_class = ProGenConfig
     base_model_prefix = "transformer"
-    is_parallelizable =
+    is_parallelizable = False
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -347,61 +338,16 @@ class ProGenModel(ProGenPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.vocab_size_emb = config.vocab_size_emb
-        self.embed_dim = config.
+        self.embed_dim = config.embed_dim
         self.wte = nn.Embedding(config.vocab_size_emb, self.embed_dim)
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
         self.rotary_dim = min(
-            config.rotary_dim, config.n_positions // config.
+            config.rotary_dim, config.n_positions // config.n_head
         )
         self.init_weights()
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-    def parallelize(self, device_map=None):
-        # Check validity of device_map
-        self.device_map = (
-            get_device_map(len(self.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.h))
-        self.model_parallel = True
-        self.first_device = (
-            "cpu"
-            if "cpu" in self.device_map.keys()
-            else "cuda:" + str(min(self.device_map.keys()))
-        )
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
-        self.wte = self.wte.to(self.first_device)
-        # Load onto devices
-        for k, v in self.device_map.items():
-            for block in v:
-                cuda_device = "cuda:" + str(k)
-                self.h[block] = self.h[block].to(cuda_device)
-        # ln_f to last
-        self.ln_f = self.ln_f.to(self.last_device)
-
-    def deparallelize(self):
-        self.model_parallel = False
-        self.device_map = None
-        self.first_device = "cpu"
-        self.last_device = "cpu"
-        self.wte = self.wte.to("cpu")
-        for index in range(len(self.h)):
-            self.h[index] = self.h[index].to("cpu")
-        self.ln_f = self.ln_f.to("cpu")
-        torch.cuda.empty_cache()
-
-    def get_input_embeddings(self):
-        return self.wte
-
-    def set_input_embeddings(self, new_embeddings):
-        self.wte = new_embeddings
-
     def forward(
         self,
         input_ids=None,
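With the hand-rolled parallelize()/deparallelize() path deleted, multi-GPU placement has to come from elsewhere. One common option, shown only as a hedged sketch (not something this commit adds; the checkpoint path is a placeholder, and trust_remote_code assumes the ProGen* classes ship alongside the checkpoint), is Accelerate-backed dispatching at load time:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/progen-checkpoint",  # placeholder; replace with a real local path or repo id
    device_map="auto",            # requires the accelerate package; shards layers across devices
    trust_remote_code=True,       # needed when the ProGen* classes are loaded from the repo
)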
@@ -510,19 +456,6 @@ class ProGenModel(ProGenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(
-                        past_state.to(hidden_states.device) for past_state in layer_past
-                    )
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -567,12 +500,6 @@ class ProGenModel(ProGenPreTrainedModel):
                     outputs[2 if use_cache else 1],
                 )
 
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
         hidden_states = self.ln_f(hidden_states)
 
         hidden_states = hidden_states.view(*output_shape)
@@ -591,7 +518,7 @@ class ProGenModel(ProGenPreTrainedModel):
                 ]
                 if v is not None
             )
-
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=presents,
@@ -610,37 +537,9 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = ProGenModel(config)
-        self.lm_head = nn.Linear(config.
+        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size_lm_head)
         self.init_weights()
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-    # def parallelize(self, device_map=None):
-    #     self.device_map = (
-    #         get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-    #         if device_map is None
-    #         else device_map
-    #     )
-    #     assert_device_map(self.device_map, len(self.transformer.h))
-    #     self.transformer.parallelize(self.device_map)
-    #     self.lm_head = self.lm_head.to(self.transformer.first_device)
-    #     self.model_parallel = True
-
-    # def deparallelize(self):
-    #     self.transformer.deparallelize()
-    #     self.transformer = self.transformer.to("cpu")
-    #     self.lm_head = self.lm_head.to("cpu")
-    #     self.model_parallel = False
-    #     torch.cuda.empty_cache()
-
-    # def get_output_embeddings(self):
-    #     return None
-
-    # def set_output_embeddings(self, new_embeddings):
-    #     return
-
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
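What the new head line computes, in isolation (toy sizes; the real values come from ProGenConfig.embed_dim and ProGenConfig.vocab_size_lm_head):

import torch
import torch.nn as nn

embed_dim, vocab_size_lm_head = 32, 64
lm_head = nn.Linear(embed_dim, vocab_size_lm_head)

hidden_states = torch.randn(2, 10, embed_dim)         # (B, T, E) from the transformer
lm_logits = lm_head(hidden_states).to(torch.float32)  # (B, T, vocab) for sampling / loss
print(lm_logits.shape)                                # torch.Size([2, 10, 64])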
@@ -650,7 +549,6 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
             token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
 
         attention_mask = kwargs.get("attention_mask", None)
-        # print("attention_mask", attention_mask)
         position_ids = kwargs.get("position_ids", None)
 
         if attention_mask is not None and position_ids is None:
@@ -694,8 +592,7 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
-
-        # print(attention_mask)
+
         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
@@ -711,11 +608,6 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
 
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
-
         # make sure sampling in fp16 works correctly and
         # compute loss in fp32 to match with mesh-tf version
         # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
@@ -726,12 +618,10 @@ class ProGenForCausalLM(ProGenPreTrainedModel):
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(
                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
             )
-
             loss = loss.to(hidden_states.dtype)
 
         if not return_dict:
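The label shift above, in miniature (toy tensors): the logits at position t are scored against the token at position t + 1, which is the standard causal-LM objective.

import torch
from torch.nn import CrossEntropyLoss

B, T, V = 2, 6, 11
lm_logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))

shift_logits = lm_logits[..., :-1, :].contiguous()   # predictions for positions 0 .. T-2
shift_labels = labels[..., 1:].contiguous()          # targets are the next tokens
loss = CrossEntropyLoss()(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
print(loss.item())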