pranjalchitale committed on
Commit
aba31e1
·
verified ·
1 Parent(s): ca1b3a6

Fixes weight tying between the decoder input embeddings and the LM head.

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +8 -7
modeling_indictrans.py CHANGED
@@ -1644,7 +1644,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
1644
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1645
  class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
1646
  base_model_prefix = "model"
1647
- _tied_weights_keys = None
1648
  _label_smoothing = 0.0
1649
 
1650
  def __init__(self, config: IndicTransConfig):
@@ -1654,19 +1654,20 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
1654
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1655
  )
1656
 
1657
- if config.share_decoder_input_output_embed:
1658
- self.lm_head.weight = self.model.decoder.embed_tokens.weight
1659
-
1660
  self.post_init()
1661
 
1662
  def tie_weights(self):
1663
- pass
 
1664
 
1665
  def get_encoder(self):
1666
- return self.model.get_encoder()
1667
 
1668
  def get_decoder(self):
1669
- return self.model.get_decoder()
 
 
 
1670
 
1671
  def get_output_embeddings(self):
1672
  return self.lm_head
 
1644
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1645
  class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
1646
  base_model_prefix = "model"
1647
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
1648
  _label_smoothing = 0.0
1649
 
1650
  def __init__(self, config: IndicTransConfig):
 
1654
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1655
  )
1656
 
 
 
 
1657
  self.post_init()
1658
 
1659
  def tie_weights(self):
1660
+ if self.config.share_decoder_input_output_embed:
1661
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.lm_head)
1662
 
1663
  def get_encoder(self):
1664
+ return self.model.encoder
1665
 
1666
  def get_decoder(self):
1667
+ return self.model.decoder
1668
+
1669
+ def get_input_embeddings(self):
1670
+ return self.model.encoder.embed_tokens
1671
 
1672
  def get_output_embeddings(self):
1673
  return self.lm_head