tcapelle committed on
Commit 67b4c35 · verified · 1 Parent(s): 27aaa7f

Upload model

config.json CHANGED
@@ -4,6 +4,10 @@
    "MultiHeadDebertaForSequenceClassificationModel"
  ],
  "attention_probs_dropout_prob": 0.1,
+ "auto_map": {
+   "AutoConfig": "configuration_deberta_multi.MultiHeadDebertaV2Config",
+   "AutoModelForSequenceClassification": "modelling_deberta_multi.MultiHeadDebertaForSequenceClassificationModel"
+ },
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
configuration_deberta_multi.py ADDED
@@ -0,0 +1,7 @@
+ from transformers import DebertaV2Config
+
+ class MultiHeadDebertaV2Config(DebertaV2Config):
+     model_type = "multi-head-deberta-for-sequence-classification"
+     def __init__(self, num_heads=5, **kwargs):
+         self.num_heads = num_heads
+         super().__init__(**kwargs)
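
The subclass only adds a num_heads field on top of DebertaV2Config, so it should behave like any other transformers config: the value is stored as an attribute and round-trips through save_pretrained / from_pretrained. A small sketch, assuming the file is importable locally:

from configuration_deberta_multi import MultiHeadDebertaV2Config

cfg = MultiHeadDebertaV2Config(num_heads=5, hidden_size=768)
print(cfg.model_type)  # multi-head-deberta-for-sequence-classification
print(cfg.num_heads)   # 5, serialized into config.json by save_pretrained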
modelling_deberta_multi.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from torch import nn, Tensor
+ from typing import Optional
+ from transformers import DebertaV2PreTrainedModel, DebertaV2Model
+ from .configuration_deberta_multi import MultiHeadDebertaV2Config
+
+ class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
+
+     config_class = MultiHeadDebertaV2Config
+     def __init__(self, config):  # type: ignore
+         super().__init__(config)
+         self.deberta = DebertaV2Model(config)
+         self.heads = nn.ModuleList(
+             [nn.Linear(config.hidden_size, 4) for _ in range(config.num_heads)]
+         )
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional["Tensor"] = None,
+         attention_mask: Optional["Tensor"] = None,
+     ) -> "Tensor":
+         outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
+         sequence_output = outputs[0]
+         logits_list = [
+             head(self.dropout(sequence_output[:, 0, :])) for head in self.heads
+         ]
+         logits = torch.stack(logits_list, dim=1)
+         return logits
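
The forward pass classifies the first ([CLS]) token with every head and stacks the per-head logits, so the output has shape (batch_size, num_heads, 4): one 4-way classification per head rather than a single logits vector. A usage sketch, again assuming a hypothetical repo id and that the repo also ships a compatible tokenizer:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo_id = "tcapelle/multi-head-deberta"  # hypothetical repo id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)

batch = tokenizer(["an example sentence"], return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(logits.shape)  # torch.Size([1, 5, 4]) with the default num_heads=5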