moontidef commited on
Commit
08469c4
·
1 Parent(s): 13c4251

feat: add support for sentence classifier

Browse files
config.json CHANGED
@@ -3,8 +3,12 @@
3
  "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
4
  "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
5
  "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
6
- "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM"
 
7
  },
 
 
 
8
  "attention_probs_dropout_prob": 0.1,
9
  "bos_token_id": 0,
10
  "eos_token_id": 2,
 
3
  "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
4
  "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
5
  "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
6
+ "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
7
+ "AutoModelForSequenceClassification":"modeling_xlm_roberta.XLMRobertaForSequenceClassification"
8
  },
9
+ "architectures": [
10
+ "XLMRobertaModel"
11
+ ],
12
  "attention_probs_dropout_prob": 0.1,
13
  "bos_token_id": 0,
14
  "eos_token_id": 2,
convert_roberta_weights_to_flash.py CHANGED
@@ -1,10 +1,11 @@
1
  import re
2
  from collections import OrderedDict
3
  from transformers import PretrainedConfig
4
- from transformers import XLMRobertaForMaskedLM
5
 
6
  from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
7
- from .modeling_xlm_roberta import XLMRobertaForMaskedLM as BertModel
 
8
  import torch
9
 
10
  import click
@@ -137,14 +138,23 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
137
 
138
  @click.command()
139
  @click.option('--model_name', default='FacebookAI/xlm-roberta-base', help='model name')
 
 
140
  @click.option('--output', default='converted_roberta_weights.bin', help='model name')
141
- def main(model_name, output):
142
- roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name)
 
 
 
 
143
  config = BertConfig.from_dict(roberta_model.config.to_dict())
144
  state_dict = roberta_model.state_dict()
145
  new_state_dict = remap_state_dict(state_dict, config)
146
-
147
- flash_model = BertModel(config)
 
 
 
148
 
149
  for k, v in flash_model.state_dict().items():
150
  if k not in new_state_dict:
 
1
  import re
2
  from collections import OrderedDict
3
  from transformers import PretrainedConfig
4
+ from transformers import XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification
5
 
6
  from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
7
+ from .modeling_xlm_roberta import XLMRobertaForMaskedLM as FlashXLMRobertaForMaskedLM
8
+ from .modeling_xlm_roberta import XLMRobertaForSequenceClassification as FlashXLMRobertaForSequenceClassification
9
  import torch
10
 
11
  import click
 
138
 
139
  @click.command()
140
  @click.option('--model_name', default='FacebookAI/xlm-roberta-base', help='model name')
141
+ @click.option('--revision', default='main', help='revision')
142
+ @click.option('--task', default='masked_lm', help='task')
143
  @click.option('--output', default='converted_roberta_weights.bin', help='model name')
144
+ def main(model_name, revision, task, output):
145
+
146
+ if task == 'masked_lm':
147
+ roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name, revision=revision)
148
+ elif task == 'sequence_classification':
149
+ roberta_model = XLMRobertaForSequenceClassification.from_pretrained(model_name, revision=revision,num_labels=1)
150
  config = BertConfig.from_dict(roberta_model.config.to_dict())
151
  state_dict = roberta_model.state_dict()
152
  new_state_dict = remap_state_dict(state_dict, config)
153
+
154
+ if task == 'masked_lm':
155
+ flash_model = FlashXLMRobertaForMaskedLM(config)
156
+ elif task == 'sequence_classification':
157
+ flash_model = FlashXLMRobertaForSequenceClassification(config)
158
 
159
  for k, v in flash_model.state_dict().items():
160
  if k not in new_state_dict:
modeling_xlm_roberta.py CHANGED
@@ -19,10 +19,11 @@ import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  import torch.utils.checkpoint
 
22
  from einops import rearrange
23
  from transformers import PretrainedConfig
24
  from transformers.modeling_utils import PreTrainedModel
25
- from transformers.modeling_outputs import MaskedLMOutput
26
  from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
27
 
28
  from transformers.models.bert.modeling_bert import (
@@ -1139,3 +1140,117 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
1139
  )
1140
 
1141
  return state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  import torch.utils.checkpoint
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
  from einops import rearrange
24
  from transformers import PretrainedConfig
25
  from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
27
  from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
 
29
  from transformers.models.bert.modeling_bert import (
 
1140
  )
1141
 
1142
  return state_dict
1143
+
1144
+
1145
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
1146
+ class XLMRobertaClassificationHead(nn.Module):
1147
+ """Head for sentence-level classification tasks."""
1148
+
1149
+ def __init__(self, config):
1150
+ super().__init__()
1151
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1152
+ classifier_dropout = (
1153
+ config.classifier_dropout
1154
+ if config.classifier_dropout is not None
1155
+ else config.hidden_dropout_prob
1156
+ )
1157
+ self.dropout = nn.Dropout(classifier_dropout)
1158
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1159
+
1160
+ def forward(self, features, **kwargs):
1161
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1162
+ x = self.dropout(x)
1163
+ x = self.dense(x)
1164
+ x = torch.tanh(x)
1165
+ x = self.dropout(x)
1166
+ x = self.out_proj(x)
1167
+ return x
1168
+
1169
+
1170
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
1171
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
1172
+ def __init__(self, config):
1173
+ super().__init__(config)
1174
+ self.num_labels = config.num_labels
1175
+ self.config = config
1176
+
1177
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
1178
+ self.classifier = XLMRobertaClassificationHead(config)
1179
+
1180
+ # Initialize weights and apply final processing
1181
+ self.post_init()
1182
+
1183
+ def forward(
1184
+ self,
1185
+ input_ids: Optional[torch.LongTensor] = None,
1186
+ attention_mask: Optional[torch.FloatTensor] = None,
1187
+ token_type_ids: Optional[torch.LongTensor] = None,
1188
+ position_ids: Optional[torch.LongTensor] = None,
1189
+ head_mask: Optional[torch.FloatTensor] = None,
1190
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1191
+ labels: Optional[torch.LongTensor] = None,
1192
+ output_attentions: Optional[bool] = None,
1193
+ output_hidden_states: Optional[bool] = None,
1194
+ return_dict: Optional[bool] = None,
1195
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1196
+ r"""
1197
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1198
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1199
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1200
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1201
+ """
1202
+ return_dict = (
1203
+ return_dict if return_dict is not None else self.config.use_return_dict
1204
+ )
1205
+
1206
+ outputs = self.roberta(
1207
+ input_ids,
1208
+ attention_mask=attention_mask,
1209
+ token_type_ids=token_type_ids,
1210
+ position_ids=position_ids,
1211
+ head_mask=head_mask,
1212
+ inputs_embeds=inputs_embeds,
1213
+ output_attentions=output_attentions,
1214
+ output_hidden_states=output_hidden_states,
1215
+ return_dict=return_dict,
1216
+ )
1217
+ sequence_output = outputs[0]
1218
+ logits = self.classifier(sequence_output)
1219
+
1220
+ loss = None
1221
+ if labels is not None:
1222
+ # move labels to correct device to enable model parallelism
1223
+ labels = labels.to(logits.device)
1224
+ if self.config.problem_type is None:
1225
+ if self.num_labels == 1:
1226
+ self.config.problem_type = "regression"
1227
+ elif self.num_labels > 1 and (
1228
+ labels.dtype == torch.long or labels.dtype == torch.int
1229
+ ):
1230
+ self.config.problem_type = "single_label_classification"
1231
+ else:
1232
+ self.config.problem_type = "multi_label_classification"
1233
+
1234
+ if self.config.problem_type == "regression":
1235
+ loss_fct = MSELoss()
1236
+ if self.num_labels == 1:
1237
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1238
+ else:
1239
+ loss = loss_fct(logits, labels)
1240
+ elif self.config.problem_type == "single_label_classification":
1241
+ loss_fct = CrossEntropyLoss()
1242
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1243
+ elif self.config.problem_type == "multi_label_classification":
1244
+ loss_fct = BCEWithLogitsLoss()
1245
+ loss = loss_fct(logits, labels)
1246
+
1247
+ if not return_dict:
1248
+ output = (logits,) + outputs[2:]
1249
+ return ((loss,) + output) if loss is not None else output
1250
+
1251
+ return SequenceClassifierOutput(
1252
+ loss=loss,
1253
+ logits=logits,
1254
+ hidden_states=outputs.hidden_states,
1255
+ attentions=outputs.attentions,
1256
+ )