pranjalchitale
committed on
Fixes Tie Weights.
Browse files
- modeling_indictrans.py +8 -7
modeling_indictrans.py
CHANGED
@@ -1644,7 +1644,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
|
|
1644 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1645 |
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
|
1646 |
base_model_prefix = "model"
|
1647 |
-
_tied_weights_keys =
|
1648 |
_label_smoothing = 0.0
|
1649 |
|
1650 |
def __init__(self, config: IndicTransConfig):
|
@@ -1654,19 +1654,20 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
|
|
1654 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1655 |
)
|
1656 |
|
1657 |
-
if config.share_decoder_input_output_embed:
|
1658 |
-
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1659 |
-
|
1660 |
self.post_init()
|
1661 |
|
1662 |
def tie_weights(self):
|
1663 |
-
|
|
|
1664 |
|
1665 |
def get_encoder(self):
|
1666 |
-
return self.model.
|
1667 |
|
1668 |
def get_decoder(self):
|
1669 |
-
return self.model.
|
|
|
|
|
|
|
1670 |
|
1671 |
def get_output_embeddings(self):
|
1672 |
return self.lm_head
|
|
|
1644 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1645 |
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
|
1646 |
base_model_prefix = "model"
|
1647 |
+
_tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
|
1648 |
_label_smoothing = 0.0
|
1649 |
|
1650 |
def __init__(self, config: IndicTransConfig):
|
|
|
1654 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1655 |
)
|
1656 |
|
|
|
|
|
|
|
1657 |
self.post_init()
|
1658 |
|
1659 |
def tie_weights(self):
|
1660 |
+
if self.config.share_decoder_input_output_embed:
|
1661 |
+
self._tie_or_clone_weights(self.decoder.embed_tokens, self.lm_head)
|
1662 |
|
1663 |
def get_encoder(self):
|
1664 |
+
return self.model.encoder
|
1665 |
|
1666 |
def get_decoder(self):
|
1667 |
+
return self.model.decoder
|
1668 |
+
|
1669 |
+
def get_input_embeddings(self):
|
1670 |
+
return self.model.encoder.embed_tokens
|
1671 |
|
1672 |
def get_output_embeddings(self):
|
1673 |
return self.lm_head
|