Change `wte` to use shared embedding
Browse files
As the title says: this will help when we need to wrap `self.wte` with FSDP.
- modeling_mpt.py +3 -2
modeling_mpt.py
CHANGED
@@ -12,6 +12,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
|
|
12 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
13 |
from .attention import attn_bias_shape, build_attn_bias
|
14 |
from .blocks import MPTBlock
|
|
|
15 |
from .norm import NORM_CLASS_REGISTRY
|
16 |
from .configuration_mpt import MPTConfig
|
17 |
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
@@ -44,7 +45,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
44 |
raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
|
45 |
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
|
46 |
self.embedding_fraction = config.embedding_fraction
|
47 |
-
self.wte =
|
48 |
if not self.alibi:
|
49 |
self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
50 |
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
@@ -252,7 +253,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
252 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
253 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
254 |
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
255 |
-
logits =
|
256 |
if self.logit_scale is not None:
|
257 |
if self.logit_scale == 0:
|
258 |
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|
|
|
12 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
13 |
from .attention import attn_bias_shape, build_attn_bias
|
14 |
from .blocks import MPTBlock
|
15 |
+
from .custom_embedding import SharedEmbedding
|
16 |
from .norm import NORM_CLASS_REGISTRY
|
17 |
from .configuration_mpt import MPTConfig
|
18 |
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
|
|
45 |
raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
|
46 |
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
|
47 |
self.embedding_fraction = config.embedding_fraction
|
48 |
+
self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
|
49 |
if not self.alibi:
|
50 |
self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
51 |
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
|
|
253 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
254 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
255 |
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
256 |
+
logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
|
257 |
if self.logit_scale is not None:
|
258 |
if self.logit_scale == 0:
|
259 |
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|