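"""HuggingFace-compatible sequence classification model built on minGRU.

Defines a device-aware wrapper around the raw minGRU module, a PreTrainedModel
base class with weight initialization, and a classification head with custom
saving logic.
"""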
import os
from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_minGRU import MinGRUConfig
from minGRU_pytorch.minGRU import minGRU

class MinGRUWrapped(nn.Module):
    """Wraps the raw minGRU module and moves incoming tensors to its device."""

    def __init__(self, min_gru_model):
        super().__init__()
        self.min_gru_model = min_gru_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, *args, **kwargs):
        # Move tensor arguments to the wrapper's device before delegating to minGRU.
        args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return self.min_gru_model(*args, **kwargs)

    def to(self, device):
        # Keep the tracked device in sync when the wrapper is moved explicitly.
        self.device = device
        self.min_gru_model.to(device)
        return self

class MinGRUPreTrainedModel(PreTrainedModel):
    config_class = MinGRUConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Guard against NaNs in any parameter of the module after initialization.
        for name, param in module.named_parameters():
            if torch.isnan(param).any():
                print(f"NaN detected in parameter {name}. Replacing with a safe number.")
                param.data = torch.nan_to_num(param.data, nan=1e-6)

class MinGRUForSequenceClassification(PreTrainedModel):
    config_class = MinGRUConfig
    base_model_prefix = "model"

    def __init__(self, config: MinGRUConfig):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        raw_min_gru = minGRU(
            dim=config.d_model,
            expansion_factor=config.ff_mult,
        )
        self.model = MinGRUWrapped(raw_min_gru)
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.d_model, config.num_labels),
        )
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ):
        embeddings = self.embedding(input_ids)
        # minGRU returns per-token hidden states; mean-pool over the sequence dimension.
        hidden_states = self.model(embeddings)
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
for name, param in model.named_parameters():
if name in ['embedding.weight', 'model.min_gru_model.to_hidden_and_gate.weight', 'model.min_gru_model.to_out.weight']:
if param is None or torch.isnan(param).any() or torch.isinf(param).any():
nn.init.xavier_normal_(param) # Başlatma işlemi
print(f"Initialized parameter {name} manually.")
return model
"""

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """
        Save the model and configuration to a directory.

        Args:
            save_directory (str): Directory to save the model.
            safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
            kwargs: Additional arguments like max_shard_size (ignored in this implementation).
        """
        os.makedirs(save_directory, exist_ok=True)
        if safe_serialization:
            print("Saving with safe serialization.")
            state_dict = {}
            # Collect the embedding, minGRU, and classifier parameters so the
            # full model can be restored; the embedding was previously omitted.
            for name, param in self.embedding.named_parameters():
                state_dict[f"embedding.{name}"] = param
            for name, param in self.model.min_gru_model.named_parameters():
                state_dict[f"model.{name}"] = param
            for name, param in self.classifier.named_parameters():
                state_dict[f"classifier.{name}"] = param
            # The configuration dict is stored alongside the weights for convenience.
            state_dict['config'] = self.config.__dict__
            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
            self.config.save_pretrained(save_directory)
        else:
            print("Saving without safe serialization.")
            super().save_pretrained(save_directory)
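
if __name__ == "__main__":
    # Minimal smoke-test sketch: builds a tiny config and runs one forward pass.
    # The MinGRUConfig keyword arguments below are assumptions inferred from the
    # attributes accessed above (vocab_size, d_model, ff_mult,
    # hidden_dropout_prob, num_labels); adjust them to the actual
    # configuration class if its signature differs.
    config = MinGRUConfig(
        vocab_size=1000,
        d_model=64,
        ff_mult=2,
        hidden_dropout_prob=0.1,
        num_labels=2,
    )
    model = MinGRUForSequenceClassification(config)

    # Keep inputs and weights on the same device the wrapper selects at init.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_ids = torch.randint(0, config.vocab_size, (4, 16), device=device)
    labels = torch.randint(0, config.num_labels, (4,), device=device)

    output = model(input_ids=input_ids, labels=labels)
    print(output.logits.shape, output.loss)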