suayptalha committed
Commit c19f25f · verified · 1 Parent(s): 0ff2ef1

Create modeling_minGRU.py

Files changed (1)
  1. modeling_minGRU.py +140 -0
modeling_minGRU.py ADDED
@@ -0,0 +1,140 @@
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional
from .configuration_minGRU import MinGRUConfig
from minGRU_pytorch.minGRU import minGRU

class MinGRUWrapped(nn.Module):
    """Thin wrapper that moves tensor arguments onto the wrapper's device before calling minGRU."""

    def __init__(self, min_gru_model):
        super().__init__()
        self.min_gru_model = min_gru_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, *args, **kwargs):
        # Move every tensor argument to the wrapper's device; leave everything else untouched.
        args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return self.min_gru_model(*args, **kwargs)

    def to(self, device):
        self.device = device
        self.min_gru_model.to(device)
        return self

class MinGRUPreTrainedModel(PreTrainedModel):
    config_class = MinGRUConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Replace any NaN entries that slipped into the parameters with a small constant.
        for name, param in module.named_parameters():
            if torch.isnan(param).any():
                print(f"NaN detected in parameter {name}. Replacing with a safe number.")
                param.data = torch.nan_to_num(param.data, nan=1e-6)

class MinGRUForSequenceClassification(MinGRUPreTrainedModel):
    config_class = MinGRUConfig
    base_model_prefix = "model"

    def __init__(self, config: MinGRUConfig):
        super().__init__(config)

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        raw_min_gru = minGRU(
            dim=config.d_model,
            expansion_factor=config.ff_mult
        )
        self.model = MinGRUWrapped(raw_min_gru)

        # Final linear layer for classification
        self.classifier = nn.Linear(config.d_model, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ):
        embeddings = self.embedding(input_ids)

        # Run the minGRU over the embedded sequence: (batch, seq_len, d_model).
        hidden_states = self.model(embeddings)

        # Mean-pool over the sequence dimension, then project to label logits.
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load model from a pretrained checkpoint.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        # Re-initialize the key weight matrices if they came back as NaN or Inf.
        for name, param in model.named_parameters():
            if name in ['embedding.weight', 'model.min_gru_model.to_hidden_and_gate.weight', 'model.min_gru_model.to_out.weight']:
                if param is None or torch.isnan(param).any() or torch.isinf(param).any():
                    nn.init.xavier_normal_(param)  # manual re-initialization
                    print(f"Initialized parameter {name} manually.")

        return model

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """
        Save the model and configuration to a directory.

        Args:
            save_directory (str): Directory to save the model.
            safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
            kwargs: Additional arguments like max_shard_size (ignored in this implementation).
        """
        import os
        os.makedirs(save_directory, exist_ok=True)

        if safe_serialization:
            print("Saving with safe serialization.")

            # Build the state dict by hand, using the same key names that the model's
            # named_parameters() produces so from_pretrained can match them on load.
            state_dict = {}

            for name, param in self.embedding.named_parameters():
                state_dict[f"embedding.{name}"] = param

            for name, param in self.model.min_gru_model.named_parameters():
                state_dict[f"model.min_gru_model.{name}"] = param

            for name, param in self.classifier.named_parameters():
                state_dict[f"classifier.{name}"] = param

            state_dict['config'] = self.config.__dict__
            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

            self.config.save_pretrained(save_directory)
        else:
            print("Saving without safe serialization.")
            super().save_pretrained(save_directory)
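
Minimal usage sketch (not part of the committed file). It assumes the configuration class accepts vocab_size, d_model, ff_mult, num_labels and initializer_range as keyword arguments (the attributes the modeling code reads), and that modeling_minGRU.py and configuration_minGRU.py sit together in an importable package, called minGRU_hf here purely for illustration; the concrete sizes are arbitrary.

import torch
from minGRU_hf.configuration_minGRU import MinGRUConfig            # hypothetical package name
from minGRU_hf.modeling_minGRU import MinGRUForSequenceClassification

# Illustrative configuration values only; real checkpoints define their own.
config = MinGRUConfig(
    vocab_size=32000,
    d_model=512,
    ff_mult=4,
    num_labels=2,
    initializer_range=0.02,
)
model = MinGRUForSequenceClassification(config)

# MinGRUWrapped routes tensor inputs to CUDA whenever it is available, so on a GPU
# machine move the model and inputs there as well to keep devices consistent.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

input_ids = torch.randint(0, config.vocab_size, (2, 16), device=device)  # (batch, seq_len)
labels = torch.tensor([0, 1], device=device)

outputs = model(input_ids=input_ids, labels=labels)
print(outputs.logits.shape)  # torch.Size([2, 2]): one score per label
print(outputs.loss)          # cross-entropy loss over the two examples

# Round-trip through the custom save/load path defined above.
model.save_pretrained("minGRU-checkpoint")
reloaded = MinGRUForSequenceClassification.from_pretrained("minGRU-checkpoint")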