daking bcui19 commited on
Commit
fe31052
·
1 Parent(s): 45aab4f

Change `wte` to use shared embedding (#43)

Browse files

- Change `wte` to use shared embedding (5fdc9e61729edbd287d220ff11abe5f0fadcc805)


Co-authored-by: Brandon Cui <[email protected]>

Files changed (1) hide show
  1. modeling_mpt.py +3 -2
modeling_mpt.py CHANGED
@@ -12,6 +12,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
 
15
  from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
@@ -44,7 +45,7 @@ class MPTModel(MPTPreTrainedModel):
44
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
45
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
46
  self.embedding_fraction = config.embedding_fraction
47
- self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
48
  if not self.alibi:
49
  self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
50
  self.emb_drop = nn.Dropout(config.emb_pdrop)
@@ -252,7 +253,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
252
  return_dict = return_dict if return_dict is not None else self.config.return_dict
253
  use_cache = use_cache if use_cache is not None else self.config.use_cache
254
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
255
- logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
256
  if self.logit_scale is not None:
257
  if self.logit_scale == 0:
258
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
+ from .custom_embedding import SharedEmbedding
16
  from .norm import NORM_CLASS_REGISTRY
17
  from .configuration_mpt import MPTConfig
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 
45
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
46
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
47
  self.embedding_fraction = config.embedding_fraction
48
+ self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
49
  if not self.alibi:
50
  self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
51
  self.emb_drop = nn.Dropout(config.emb_pdrop)
 
253
  return_dict = return_dict if return_dict is not None else self.config.return_dict
254
  use_cache = use_cache if use_cache is not None else self.config.use_cache
255
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
256
+ logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
257
  if self.logit_scale is not None:
258
  if self.logit_scale == 0:
259
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')