pranjalchitale
commited on
Commit
·
939611a
1
Parent(s):
aba31e1
Fixes TieWeights
Browse files- modeling_indictrans.py +2 -2
modeling_indictrans.py
CHANGED
@@ -1657,8 +1657,8 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
|
|
1657 |
self.post_init()
|
1658 |
|
1659 |
def tie_weights(self):
|
1660 |
-
|
1661 |
-
|
1662 |
|
1663 |
def get_encoder(self):
|
1664 |
return self.model.encoder
|
|
|
1657 |
self.post_init()
|
1658 |
|
1659 |
def tie_weights(self):
|
1660 |
+
if self.config.share_decoder_input_output_embed:
|
1661 |
+
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
|
1662 |
|
1663 |
def get_encoder(self):
|
1664 |
return self.model.encoder
|