zhihan1996
commited on
Commit
·
6041066
1
Parent(s):
69b2c8f
Update bert_layers.py
Browse files- bert_layers.py +5 -68
bert_layers.py
CHANGED
@@ -698,38 +698,6 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
698 |
# Initialize weights and apply final processing
|
699 |
self.post_init()
|
700 |
|
701 |
-
@classmethod
|
702 |
-
def from_composer(cls,
|
703 |
-
pretrained_checkpoint,
|
704 |
-
state_dict=None,
|
705 |
-
cache_dir=None,
|
706 |
-
from_tf=False,
|
707 |
-
config=None,
|
708 |
-
*inputs,
|
709 |
-
**kwargs):
|
710 |
-
"""Load from pre-trained."""
|
711 |
-
model = cls(config, *inputs, **kwargs)
|
712 |
-
if from_tf:
|
713 |
-
raise ValueError(
|
714 |
-
'Mosaic BERT does not support loading TensorFlow weights.')
|
715 |
-
|
716 |
-
state_dict = torch.load(pretrained_checkpoint)
|
717 |
-
# If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
|
718 |
-
consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
|
719 |
-
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
720 |
-
strict=False)
|
721 |
-
|
722 |
-
if len(missing_keys) > 0:
|
723 |
-
logger.warning(
|
724 |
-
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
725 |
-
)
|
726 |
-
if len(unexpected_keys) > 0:
|
727 |
-
logger.warning(
|
728 |
-
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
|
729 |
-
)
|
730 |
-
|
731 |
-
return model
|
732 |
-
|
733 |
def get_output_embeddings(self):
|
734 |
return self.cls.predictions.decoder
|
735 |
|
@@ -786,7 +754,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
786 |
return_dict=return_dict,
|
787 |
masked_tokens_mask=masked_tokens_mask,
|
788 |
)
|
789 |
-
|
790 |
sequence_output = outputs[0]
|
791 |
prediction_scores = self.cls(sequence_output)
|
792 |
|
@@ -813,8 +781,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
813 |
return MaskedLMOutput(
|
814 |
loss=loss,
|
815 |
logits=prediction_scores,
|
816 |
-
hidden_states=outputs
|
817 |
-
attentions=
|
818 |
)
|
819 |
|
820 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
@@ -868,37 +836,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
868 |
# Initialize weights and apply final processing
|
869 |
self.post_init()
|
870 |
|
871 |
-
@classmethod
|
872 |
-
def from_composer(cls,
|
873 |
-
pretrained_checkpoint,
|
874 |
-
state_dict=None,
|
875 |
-
cache_dir=None,
|
876 |
-
from_tf=False,
|
877 |
-
config=None,
|
878 |
-
*inputs,
|
879 |
-
**kwargs):
|
880 |
-
"""Load from pre-trained."""
|
881 |
-
model = cls(config, *inputs, **kwargs)
|
882 |
-
if from_tf:
|
883 |
-
raise ValueError(
|
884 |
-
'Mosaic BERT does not support loading TensorFlow weights.')
|
885 |
-
|
886 |
-
state_dict = torch.load(pretrained_checkpoint)
|
887 |
-
# If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
|
888 |
-
consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
|
889 |
-
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
890 |
-
strict=False)
|
891 |
-
|
892 |
-
if len(missing_keys) > 0:
|
893 |
-
logger.warning(
|
894 |
-
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
895 |
-
)
|
896 |
-
if len(unexpected_keys) > 0:
|
897 |
-
logger.warning(
|
898 |
-
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
|
899 |
-
)
|
900 |
-
|
901 |
-
return model
|
902 |
|
903 |
def forward(
|
904 |
self,
|
@@ -972,7 +909,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
972 |
return SequenceClassifierOutput(
|
973 |
loss=loss,
|
974 |
logits=logits,
|
975 |
-
hidden_states=outputs
|
976 |
-
attentions=
|
977 |
)
|
978 |
|
|
|
698 |
# Initialize weights and apply final processing
|
699 |
self.post_init()
|
700 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
def get_output_embeddings(self):
|
702 |
return self.cls.predictions.decoder
|
703 |
|
|
|
754 |
return_dict=return_dict,
|
755 |
masked_tokens_mask=masked_tokens_mask,
|
756 |
)
|
757 |
+
|
758 |
sequence_output = outputs[0]
|
759 |
prediction_scores = self.cls(sequence_output)
|
760 |
|
|
|
781 |
return MaskedLMOutput(
|
782 |
loss=loss,
|
783 |
logits=prediction_scores,
|
784 |
+
hidden_states=outputs[0],
|
785 |
+
attentions=None,
|
786 |
)
|
787 |
|
788 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
|
|
836 |
# Initialize weights and apply final processing
|
837 |
self.post_init()
|
838 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
839 |
|
840 |
def forward(
|
841 |
self,
|
|
|
909 |
return SequenceClassifierOutput(
|
910 |
loss=loss,
|
911 |
logits=logits,
|
912 |
+
hidden_states=outputs[0],
|
913 |
+
attentions=None,
|
914 |
)
|
915 |
|