tommymarto committed
Commit 7762514 · 1 Parent(s): 64cc94a

added studentbert config and modeling files

Files changed (3)
  1. config.json +9 -5
  2. configuration_mcqbert.py +10 -0
  3. modeling_mcqbert.py +46 -0
config.json CHANGED
@@ -1,22 +1,26 @@
 {
-  "_name_or_path": "tommymarto/LernnaviBERT_mcqbert1_correct_answers_4096",
-  "architectures": [
-    "MCQBert1"
-  ],
+  "_name_or_path": "epfl-ml4ed/MCQStudentBertCat",
+  "auto_map": {
+    "AutoConfig": "configuration_mcqbert.MCQBertConfig",
+    "AutoModel": "modeling_mcqbert.MCQStudentBert"
+  },
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
+  "cls_hidden_size": 256,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 768,
   "initializer_range": 0.02,
+  "integration_strategy": "cat",
   "intermediate_size": 3072,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
-  "model_type": "bert",
+  "model_type": "mcqbert",
   "num_attention_heads": 12,
   "num_hidden_layers": 12,
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
+  "student_embedding_size": 4096,
   "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "type_vocab_size": 2,
configuration_mcqbert.py ADDED
@@ -0,0 +1,10 @@
+from transformers import BertConfig
+
+class MCQBertConfig(BertConfig):
+    model_type = "mcqbert"
+
+    def __init__(self, integration_strategy=None, student_embedding_size=4096, cls_hidden_size=256, **kwargs):
+        super().__init__(**kwargs)
+        self.integration_strategy = integration_strategy
+        self.student_embedding_size = student_embedding_size
+        self.cls_hidden_size = cls_hidden_size
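
MCQBertConfig extends BertConfig with three fields and inherits every standard BERT hyperparameter. A small construction sketch mirroring the config.json values above, assuming the module is importable from the working directory:

from configuration_mcqbert import MCQBertConfig

config = MCQBertConfig(
    integration_strategy="cat",    # None, "cat", or "sum"
    student_embedding_size=4096,   # dimensionality of the incoming student vector
    cls_hidden_size=256,           # width of the classifier head's hidden layer
)
assert config.model_type == "mcqbert"
assert config.hidden_size == 768  # BertConfig default, unchanged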
modeling_mcqbert.py ADDED
@@ -0,0 +1,46 @@
+from transformers import BertModel
+import torch
+
+from .configuration_mcqbert import MCQBertConfig
+
+class MCQStudentBert(BertModel):
+    def __init__(self, config: MCQBertConfig):
+        super().__init__(config)
+
+        # projection from the raw student embedding into BERT's hidden space;
+        # only created when an integration strategy actually uses it
+        if config.integration_strategy is not None:
+            self.student_embedding_layer = torch.nn.Linear(config.student_embedding_size, config.hidden_size)
+
+        # "cat" concatenates the projected student embedding with the [CLS]
+        # state, so the classifier input is twice the hidden size
+        cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
+        cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier
+
+        self.classifier = torch.nn.Sequential(
+            torch.nn.Linear(cls_input_dim, config.cls_hidden_size),
+            torch.nn.ReLU(),
+            torch.nn.Linear(config.cls_hidden_size, 1)
+        )
+
+    def forward(self, input_ids, student_embeddings=None):
+        if self.config.integration_strategy is None:
+            # MCQBert: no integration strategy, so student embeddings are ignored
+            output = super().forward(input_ids)
+            return self.classifier(output.last_hidden_state[:, 0, :])
+
+        elif self.config.integration_strategy == "cat":
+            # MCQStudentBertCat: concatenate projected student embedding with the [CLS] state
+            output = super().forward(input_ids)
+            output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings)), dim=1)
+            return self.classifier(output_with_student_embedding)
+
+        elif self.config.integration_strategy == "sum":
+            # MCQStudentBertSum: add the projected student embedding to every token position
+            input_embeddings = self.embeddings(input_ids)
+            combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
+            output = super().forward(inputs_embeds=combined_embeddings)
+            return self.classifier(output.last_hidden_state[:, 0, :])
+
+        else:
+            raise ValueError(f"{self.config.integration_strategy} is not a known integration_strategy")
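
A hedged end-to-end sketch of a forward pass with the "cat" checkpoint; the tokenizer name is an assumption (any tokenizer matching the checkpoint's BERT vocabulary works) and the student vector is random placeholder data:

import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("epfl-ml4ed/MCQStudentBertCat", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed vocab

input_ids = tokenizer(["example MCQ interaction text"], return_tensors="pt")["input_ids"]
student_embeddings = torch.randn(1, model.config.student_embedding_size)  # placeholder 4096-d vector

with torch.no_grad():
    logits = model(input_ids, student_embeddings=student_embeddings)
print(logits.shape)  # torch.Size([1, 1]): one raw score per example

Note that forward accepts only input_ids and student_embeddings, so no attention_mask is passed; single-sequence batches like this one avoid padding issues.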